From 05398a0de4d5454fb329ac5c8585e11a72c7d86e Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Mon, 24 Nov 2025 21:43:45 -0800 Subject: [PATCH 01/12] implemented diff rigid code, with fixes based on comments --- .../entities/rigid_entity/rigid_entity.py | 198 +- genesis/engine/simulator.py | 17 +- .../engine/solvers/rigid/constraint_noslip.py | 6 + .../solvers/rigid/constraint_solver_decomp.py | 2 + .../rigid/constraint_solver_decomp_island.py | 2 +- .../solvers/rigid/rigid_solver_decomp.py | 4127 +++++++++++------ genesis/engine/states/__init__.py | 2 + genesis/engine/states/entities.py | 38 + genesis/engine/states/solvers.py | 9 +- genesis/utils/array_class.py | 196 +- genesis/utils/geom.py | 8 +- genesis/utils/path_planning.py | 4 + tests/test_grad.py | 93 +- 13 files changed, 3300 insertions(+), 1402 deletions(-) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 2d35fe665..fd9ea8a48 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -18,7 +18,8 @@ from genesis.utils import mjcf as mju from genesis.utils import terrain as tu from genesis.utils import urdf as uu -from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch +from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch, to_gs_tensor +from genesis.engine.states.entities import RigidEntityState from ..base_entity import Entity from .rigid_equality import RigidEquality @@ -97,6 +98,23 @@ def __init__( self._load_model() + # Initialize target variables and checkpoint + self._tgt_keys = ["pos", "quat", "qpos", "dofs_velocity"] + self._tgt = dict() + self._tgt_buffer = list() + self._ckpt = dict() + self._update_tgt_while_set = True + + def _update_tgt(self, key, value): + # Set [self._tgt] value while keeping the insertion order between keys. When a new key is inserted or an existing + # key is updated, the new element should be inserted at the end of the dict. This is because we need to keep + # the insertion order to correctly pass the gradients in the backward pass. + self._tgt.pop(key, None) + self._tgt[key] = value + + def init_ckpt(self): + pass + def _load_model(self): self._links = gs.List() self._joints = gs.List() @@ -1459,6 +1477,7 @@ def _kernel_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward=False, ) ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) @@ -1488,6 +1507,7 @@ def _kernel_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward=False, ) # ------------------------------------------------------------------------------------ @@ -1622,6 +1642,105 @@ def plan_path( # ------------------------------------------------------------------------------------ # ---------------------------------- control & io ------------------------------------ # ------------------------------------------------------------------------------------ + def process_input(self, in_backward=False): + if in_backward: + # use negative index because buffer length might not be full + index = self._sim.cur_step_local - self._sim._steps_local + self._tgt = self._tgt_buffer[index].copy() + else: + self._tgt_buffer.append(self._tgt.copy()) + + # Apply targets in the order of insertion + for key in self._tgt.keys(): + data_kwargs = self._tgt[key] + + # We do not need zero velocity here because if it was true, [set_dofs_velocity] from zero_velocity would + # be in [tgt] + if "zero_velocity" in data_kwargs: + data_kwargs["zero_velocity"] = False + # Do not update [tgt], as input information is finalized at this point + self._update_tgt_while_set = False + + match key: + case "pos": + self.set_pos(**data_kwargs) + case "quat": + self.set_quat(**data_kwargs) + case "qpos": + self.set_qpos(**data_kwargs) + case "dofs_velocity": + self.set_dofs_velocity(**data_kwargs) + case _: + gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") + + self._tgt = dict() + self._update_tgt_while_set = True + + def process_input_grad(self): + index = self._sim.cur_step_local - self._sim._steps_local + for key in reversed(self._tgt_buffer[index].keys()): + data_kwargs = self._tgt_buffer[index][key] + + match key: + # We need to unpack the data_kwargs because [_backward_from_ti] only supports positional arguments + case "pos": + pos = data_kwargs.pop("pos") + if pos.requires_grad: + pos._backward_from_ti( + self.set_pos_grad, data_kwargs["envs_idx"], data_kwargs["relative"], data_kwargs["unsafe"] + ) + + case "quat": + quat = data_kwargs.pop("quat") + if quat.requires_grad: + quat._backward_from_ti( + self.set_quat_grad, data_kwargs["envs_idx"], data_kwargs["relative"], data_kwargs["unsafe"] + ) + + case "qpos": + qpos = data_kwargs.pop("qpos") + if qpos.requires_grad: + raise NotImplementedError("Backward pass for set_qpos_grad is not implemented yet.") + + case "dofs_velocity": + velocity = data_kwargs.pop("velocity") + # [velocity] could be None when we want to zero the velocity (see set_dofs_velocity of RigidSolver) + if velocity is not None and velocity.requires_grad: + velocity._backward_from_ti( + self.set_dofs_velocity_grad, + data_kwargs["dofs_idx_local"], + data_kwargs["envs_idx"], + data_kwargs["unsafe"], + ) + case _: + gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") + + def save_ckpt(self, ckpt_name): + if ckpt_name not in self._ckpt: + self._ckpt[ckpt_name] = {} + self._ckpt[ckpt_name]["_tgt_buffer"] = self._tgt_buffer.copy() + self._tgt_buffer.clear() + + def load_ckpt(self, ckpt_name): + self._tgt_buffer = self._ckpt[ckpt_name]["_tgt_buffer"].copy() + + def reset_grad(self): + self._tgt_buffer.clear() + + @gs.assert_built + def get_state(self): + state = RigidEntityState(self, self._sim.cur_step_global) + + solver_state = self._solver.get_state() + pos = solver_state.links_pos[:, self.base_link_idx] + quat = solver_state.links_quat[:, self.base_link_idx] + + assert state._pos.shape == pos.shape + assert state._quat.shape == quat.shape + state._pos = pos + state._quat = quat + + return state def get_joint(self, name=None, uid=None): """ @@ -1966,6 +2085,19 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ + # Save in [tgt] for backward pass + if self._update_tgt_while_set: + self._update_tgt( + "pos", + { + "pos": pos, + "envs_idx": envs_idx, + "relative": relative, + "zero_velocity": zero_velocity, + "unsafe": unsafe, + }, + ) + if not unsafe: _pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous() if _pos is not pos: @@ -1982,6 +2114,16 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns if zero_velocity: self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) + @gs.assert_built + def set_pos_grad(self, envs_idx, relative, unsafe, pos_grad): + self._solver.set_base_links_pos_grad( + self._base_links_idx_, + envs_idx, + relative, + unsafe, + pos_grad.data, + ) + @gs.assert_built def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): """ @@ -2000,6 +2142,18 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ + # Save in [tgt] for backward pass + if self._update_tgt_while_set: + self._update_tgt( + "quat", + { + "quat": quat, + "envs_idx": envs_idx, + "relative": relative, + "zero_velocity": zero_velocity, + "unsafe": unsafe, + }, + ) if not unsafe: _quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous() if _quat is not quat: @@ -2016,6 +2170,16 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u if zero_velocity: self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) + @gs.assert_built + def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad): + self._solver.set_base_links_quat_grad( + self._base_links_idx_, + envs_idx, + relative, + unsafe, + quat_grad.data, + ) + @gs.assert_built def get_verts(self): """ @@ -2121,6 +2285,19 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True zero_velocity : bool, optional Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose. """ + # Save in [tgt] for backward pass + if self._update_tgt_while_set: + self._update_tgt( + "qpos", + { + "qpos": qpos, + "qs_idx_local": qs_idx_local, + "envs_idx": envs_idx, + "zero_velocity": zero_velocity, + "unsafe": unsafe, + }, + ) + qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True) self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity) if zero_velocity: @@ -2216,9 +2393,25 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, * envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ + # Save in [tgt] for backward pass + if self._update_tgt_while_set: + self._update_tgt( + "dofs_velocity", + { + "velocity": velocity, + "dofs_idx_local": dofs_idx_local, + "envs_idx": envs_idx, + "unsafe": unsafe, + }, + ) dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe) + @gs.assert_built + def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, unsafe, velocity_grad): + dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) + self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, unsafe, velocity_grad.data) + @gs.assert_built def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False): """ @@ -3176,6 +3369,7 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + False, ) # compute error solved = True @@ -3304,6 +3498,7 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + False, ) solved = True for i_ee in range(n_links): @@ -3403,4 +3598,5 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, + False, ) diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index ef2093e11..488e7febf 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -275,18 +275,13 @@ def f_global_to_s_global(self, f_global): # ------------------------------------------------------------------------------------ def step(self, in_backward=False): - if self._rigid_only: # "Only Advance!" --Thomas Wade :P - for _ in range(self._substeps): - self.rigid_solver.substep() - self._cur_substep_global += 1 - else: - self.process_input(in_backward=in_backward) - for _ in range(self._substeps): - self.substep(self.cur_substep_local) + self.process_input(in_backward=in_backward) + for _ in range(self._substeps): + self.substep(self.cur_substep_local) - self._cur_substep_global += 1 - if self.cur_substep_local == 0 and not in_backward: - self.save_ckpt() + self._cur_substep_global += 1 + if self.cur_substep_local == 0 and not in_backward: + self.save_ckpt() if self.rigid_solver.is_active: self.rigid_solver.clear_external_force() diff --git a/genesis/engine/solvers/rigid/constraint_noslip.py b/genesis/engine/solvers/rigid/constraint_noslip.py index 69e84500f..8e02a545c 100644 --- a/genesis/engine/solvers/rigid/constraint_noslip.py +++ b/genesis/engine/solvers/rigid/constraint_noslip.py @@ -35,10 +35,12 @@ def kernel_build_efc_AR_b( rigid_solver.func_solve_mass_batched( constraint_state.Mgrad, constraint_state.Mgrad, + array_class.PLACEHOLDER, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) # AR[r, c] = J[c, :] * tmp @@ -191,10 +193,12 @@ def kernel_dual_finish( rigid_solver.func_solve_mass_batched( vec=constraint_state.qfrc_constraint, out=constraint_state.qacc, + out_bw=array_class.PLACEHOLDER, i_b=i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) for i_d in range(n_dofs): @@ -283,10 +287,12 @@ def compute_A_diag( rigid_solver.func_solve_mass_batched( constraint_state.Mgrad, constraint_state.Mgrad, + array_class.PLACEHOLDER, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) # Ai = Ji * tmp diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index e89afab9b..a4b0a76e3 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -1989,10 +1989,12 @@ def func_update_gradient( rigid_solver.func_solve_mass_batched( constraint_state.grad, constraint_state.Mgrad, + array_class.PLACEHOLDER, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=False, ) elif ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py index 70e2ce43b..9415a98a2 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp_island.py @@ -990,7 +990,7 @@ def _func_update_gradient(self, island, i_b): i_e_ = self.contact_island.island_entity[island, i_b].start + i_island_entity i_e = self.contact_island.entity_id[i_e_, i_b] self._solver.mass_mat_mask[i_e_, i_b] = True - self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b) + self._solver._func_solve_mass_batched(self.grad, self.Mgrad, array_class.PLACEHOLDER, i_b) for i_e in range(self._solver.n_entities): self._solver.mass_mat_mask[i_e, i_b] = True elif ti.static(self._solver_type == gs.constraint_solver.Newton): diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index f2274ab09..4dc5aaedf 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -11,7 +11,7 @@ import genesis.utils.geom as gu from genesis.engine.entities import AvatarEntity, DroneEntity, RigidEntity from genesis.engine.entities.base_entity import Entity -from genesis.engine.states.solvers import RigidSolverState +from genesis.engine.states import QueriedStates, RigidSolverState from genesis.options.solvers import RigidOptions from genesis.utils import linalg as lu from genesis.utils.misc import ( @@ -36,7 +36,6 @@ from genesis.engine.scene import Scene from genesis.engine.simulator import Simulator - # minimum constraint impedance IMP_MIN = 0.0001 # maximum constraint impedance @@ -135,6 +134,13 @@ def __init__(self, scene: "Scene", sim: "Simulator", options: RigidOptions) -> N self.qpos: ti.Template | ti.types.NDArray | None = None + self._queried_states = QueriedStates() + + self._ckpt = dict() + + def init_ckpt(self): + pass + def add_entity(self, idx, material, morph, surface, visualize_contact) -> Entity: if isinstance(material, gs.materials.Avatar): EntityClass = AvatarEntity @@ -202,6 +208,14 @@ def build(self): self._n_entities = self.n_entities self._n_equalities = self.n_equalities + self._max_n_links_per_entity = self.max_n_links_per_entity + self._max_n_joints_per_link = self.max_n_joints_per_link + self._max_n_dofs_per_joint = self.max_n_dofs_per_joint + self._max_n_qs_per_link = self.max_n_qs_per_link + self._max_n_dofs_per_entity = self.max_n_dofs_per_entity + self._max_n_dofs_per_link = self.max_n_dofs_per_link + self._max_n_geoms_per_entity = self.max_n_geoms_per_entity + self._geoms = self.geoms self._vgeoms = self.vgeoms self._links = self.links @@ -273,6 +287,55 @@ def build(self): solver_type=gs.constraint_solver.CG, ) + if self._static_rigid_sim_config.requires_grad: + if self._static_rigid_sim_config.use_hibernation: + gs.raise_exception("Hibernation is not supported yet when requires_grad is True") + if self._static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast: + gs.raise_exception( + "Only approximate_implicitfast integrator is supported yet when requires_grad is True." + ) + from genesis.engine.couplers import SAPCoupler, IPCCoupler + + if isinstance(self.sim.coupler, (SAPCoupler, IPCCoupler)): + gs.raise_exception( + f"{type(self.sim.coupler).__name__} is not supported yet when requires_grad is True." + ) + + if getattr(self._options, "noslip_iterations", 0) > 0: + gs.raise_exception("Noslip is not supported yet when requires_grad is True.") + + # Add terms for static inner loops, use 0 if not requires_grad to avoid re-compilation + self._static_rigid_sim_config.max_n_links_per_entity = ( + self._max_n_links_per_entity if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_joints_per_link = ( + self._max_n_joints_per_link if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_dofs_per_joint = ( + self._max_n_dofs_per_joint if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_qs_per_link = ( + self._max_n_qs_per_link if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_dofs_per_entity = ( + self._max_n_dofs_per_entity if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_dofs_per_link = ( + self._max_n_dofs_per_link if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_geoms_per_entity = ( + self._max_n_geoms_per_entity if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_awake_links = ( + self._n_links if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_awake_entities = ( + self._n_entities if self._static_rigid_sim_config.requires_grad else 0 + ) + self._static_rigid_sim_config.max_n_awake_dofs = ( + self._n_dofs if self._static_rigid_sim_config.requires_grad else 0 + ) + # when the migration is finished, we will remove the about two lines self._func_vel_at_point = func_vel_at_point self._func_apply_coupling_force = func_apply_coupling_force @@ -284,6 +347,7 @@ def build(self): self._errno = self.data_manager.errno self._rigid_global_info = self.data_manager.rigid_global_info + self._rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache if self._use_hibernation: self.n_awake_dofs = self._rigid_global_info.n_awake_dofs self.awake_dofs = self._rigid_global_info.awake_dofs @@ -291,6 +355,11 @@ def build(self): self.awake_links = self._rigid_global_info.awake_links self.n_awake_entities = self._rigid_global_info.n_awake_entities self.awake_entities = self._rigid_global_info.awake_entities + if self._requires_grad: + self.dofs_state_adjoint_cache = self.data_manager.dofs_state_adjoint_cache + self.links_state_adjoint_cache = self.data_manager.links_state_adjoint_cache + self.joints_state_adjoint_cache = self.data_manager.joints_state_adjoint_cache + self.geoms_state_adjoint_cache = self.data_manager.geoms_state_adjoint_cache self._init_mass_mat() self._init_dof_fields() @@ -341,6 +410,7 @@ def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, u rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, decompose=True, + is_backward=False, ) # Define some proxies for convenience @@ -845,13 +915,14 @@ def _get_links_data( return tensor[mask] - def substep(self): + def substep(self, f): # from genesis.utils.tools import create_timer from genesis.engine.couplers import SAPCoupler self._links_state_cache.clear() kernel_step_1( + f=f, links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, @@ -863,8 +934,10 @@ def substep(self): entities_state=self.entities_state, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, ) if isinstance(self.sim.coupler, SAPCoupler): @@ -872,10 +945,12 @@ def substep(self): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) else: self._func_constraint_force() kernel_step_2( + f=f, dofs_state=self.dofs_state, dofs_info=self.dofs_info, links_info=self.links_info, @@ -888,8 +963,10 @@ def substep(self): geoms_state=self.geoms_state, collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, ) def check_errno(self): @@ -958,6 +1035,7 @@ def _func_forward_dynamics(self): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, ) def _func_update_acc(self): @@ -968,6 +1046,7 @@ def _func_update_acc(self): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def _func_forward_kinematics_entity(self, i_e, envs_idx): @@ -983,6 +1062,7 @@ def _func_forward_kinematics_entity(self, i_e, envs_idx): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): @@ -999,7 +1079,7 @@ def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): static_rigid_sim_config=self._static_rigid_sim_config, ) - def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False): + def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False, is_backward=False): kernel_update_geoms( envs_idx, entities_info=self.entities_info, @@ -1009,6 +1089,7 @@ def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, ) def _process_dim(self, tensor, envs_idx=None): @@ -1147,10 +1228,147 @@ def substep_pre_coupling(self, f): return # Run Genesis rigid simulation step - self.substep() + self.substep(f) def substep_pre_coupling_grad(self, f): - pass + # Run forward substep again to restore this step's information, this is needed because we do not store info + # of every substep. + kernel_prepare_backward_substep( + f=f, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + dofs_state_adjoint_cache=self.dofs_state_adjoint_cache, + links_state_adjoint_cache=self.links_state_adjoint_cache, + joints_state_adjoint_cache=self.joints_state_adjoint_cache, + geoms_state_adjoint_cache=self.geoms_state_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + self.substep(f) + + # =================== Backward substep ====================== + if not self._enable_mujoco_compatibility: + kernel_update_cartesian_space.grad( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=True, + ) + + is_grad_valid = kernel_begin_backward_substep( + f=f, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + dofs_state_adjoint_cache=self.dofs_state_adjoint_cache, + links_state_adjoint_cache=self.links_state_adjoint_cache, + joints_state_adjoint_cache=self.joints_state_adjoint_cache, + geoms_state_adjoint_cache=self.geoms_state_adjoint_cache, + rigid_adjoint_cache=self._rigid_adjoint_cache, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + if not is_grad_valid: + gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") + + kernel_step_2.grad( + f=f, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + links_info=self.links_info, + links_state=self.links_state, + joints_info=self.joints_info, + joints_state=self.joints_state, + entities_state=self.entities_state, + entities_info=self.entities_info, + geoms_info=self.geoms_info, + geoms_state=self.geoms_state, + collider_state=self.collider._collider_state, + rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, + static_rigid_sim_config=self._static_rigid_sim_config, + contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=True, + ) + + # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, + # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). + # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. + # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through + # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation + # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel. + # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a + # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid + # the data access violation. However, both of these require major changes. + kernel_compute_qacc.grad( + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + kernel_copy_acc( + f=f, + dofs_state=self.dofs_state, + rigid_adjoint_cache=self._rigid_adjoint_cache, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + + kernel_forward_dynamics_without_qacc.grad( + links_state=self.links_state, + links_info=self.links_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + joints_info=self.joints_info, + entities_state=self.entities_state, + entities_info=self.entities_info, + geoms_state=self.geoms_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=True, + ) + + # If it was the very first substep, we need to backpropagate through the initial update of the cartesian space + if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0: + kernel_update_cartesian_space.grad( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=True, + ) def substep_post_coupling(self, f): from genesis.engine.couplers import SAPCoupler, IPCCoupler @@ -1163,8 +1381,10 @@ def substep_post_coupling(self, f): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) kernel_step_2( + f=f, dofs_state=self.dofs_state, dofs_info=self.dofs_info, links_info=self.links_info, @@ -1177,8 +1397,10 @@ def substep_post_coupling(self, f): geoms_state=self.geoms_state, collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, + is_backward=False, ) elif isinstance(self.sim.coupler, IPCCoupler): # For IPCCoupler, perform full rigid body computation in post-coupling phase @@ -1187,25 +1409,56 @@ def substep_post_coupling(self, f): if self.sim.coupler.options.disable_genesis_ground_contact: original_enable_collision = self._enable_collision self._enable_collision = False - self.substep() + self.substep(f) self._enable_collision = original_enable_collision else: - self.substep() + self.substep(f) def substep_post_coupling_grad(self, f): pass def add_grad_from_state(self, state): - pass + if self.is_active: + qpos_grad = gs.zeros_like(state.qpos) + dofs_vel_grad = gs.zeros_like(state.dofs_vel) + links_pos_grad = gs.zeros_like(state.links_pos) + links_quat_grad = gs.zeros_like(state.links_quat) + + if state.qpos.grad is not None: + qpos_grad = state.qpos.grad + if state.dofs_vel.grad is not None: + dofs_vel_grad = state.dofs_vel.grad + if state.links_pos.grad is not None: + links_pos_grad = state.links_pos.grad + if state.links_quat.grad is not None: + links_quat_grad = state.links_quat.grad + + kernel_get_state_grad( + qpos_grad=qpos_grad, + vel_grad=dofs_vel_grad, + links_pos_grad=links_pos_grad, + links_quat_grad=links_quat_grad, + links_state=self.links_state, + dofs_state=self.dofs_state, + geoms_state=self.geoms_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) def collect_output_grads(self): """ Collect gradients from downstream queried states. """ - pass + if self._sim.cur_step_global in self._queried_states: + # one step could have multiple states + assert len(self._queried_states[self._sim.cur_step_global]) == 1 + state = self._queried_states[self._sim.cur_step_global][0] + self.add_grad_from_state(state) def reset_grad(self): - pass + for entity in self._entities: + entity.reset_grad() + self._queried_states.clear() def update_geoms_render_T(self): kernel_update_geoms_render_T( @@ -1225,9 +1478,13 @@ def update_vgeoms_render_T(self): static_rigid_sim_config=self._static_rigid_sim_config, ) - def get_state(self, f): + def get_state(self, f=None): + s_global = self.sim.cur_step_global if self.is_active: - state = RigidSolverState(self._scene) + if s_global in self._queried_states: + return self._queried_states[s_global][0] + + state = RigidSolverState(self._scene, s_global) # qpos: ti.types.ndarray(), # vel: ti.types.ndarray(), @@ -1245,6 +1502,7 @@ def get_state(self, f): kernel_get_state( qpos=state.qpos, vel=state.dofs_vel, + acc=state.dofs_acc, links_pos=state.links_pos, links_quat=state.links_quat, i_pos_shift=state.i_pos_shift, @@ -1256,6 +1514,7 @@ def get_state(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, ) + self._queried_states.append(state) else: state = None return state @@ -1266,6 +1525,7 @@ def set_state(self, f, state, envs_idx=None): kernel_set_state( qpos=state.qpos, dofs_vel=state.dofs_vel, + dofs_acc=state.dofs_acc, links_pos=state.links_pos, links_quat=state.links_quat, i_pos_shift=state.i_pos_shift, @@ -1291,6 +1551,7 @@ def set_state(self, f, state, envs_idx=None): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) self._errno[None] = 0 @@ -1303,16 +1564,51 @@ def set_state(self, f, state, envs_idx=None): self._cur_step = -1 def process_input(self, in_backward=False): - pass + for entity in self._entities: + entity.process_input(in_backward=in_backward) def process_input_grad(self): - pass + for entity in self._entities: + entity.process_input_grad() def save_ckpt(self, ckpt_name): - pass + # Save ckpt only if we need gradients, because this operation is costly + if self._requires_grad: + if ckpt_name not in self._ckpt: + self._ckpt[ckpt_name] = dict() + + self._ckpt[ckpt_name]["qpos"] = ti_to_numpy(self._rigid_adjoint_cache.qpos) + self._ckpt[ckpt_name]["dofs_vel"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_vel) + self._ckpt[ckpt_name]["dofs_acc"] = ti_to_numpy(self._rigid_adjoint_cache.dofs_acc) + + for entity in self._entities: + entity.save_ckpt(ckpt_name) def load_ckpt(self, ckpt_name): - pass + # Set first frame + self._rigid_global_info.qpos.from_numpy(self._ckpt[ckpt_name]["qpos"][0]) + self.dofs_state.vel.from_numpy(self._ckpt[ckpt_name]["dofs_vel"][0]) + self.dofs_state.acc.from_numpy(self._ckpt[ckpt_name]["dofs_acc"][0]) + + if not self._enable_mujoco_compatibility: + kernel_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=False, + ) + + for entity in self._entities: + entity.load_ckpt(ckpt_name) @property def is_active(self): @@ -1410,6 +1706,15 @@ def _sanitize_1D_io_variables( gs.raise_exception("Expecting 1D output tensor.") return tensor, _inputs_idx, envs_idx + def _sanitize_1D_io_variables_grad( + self, + grad_after_sanitization, + grad_before_sanitization, + ): + if grad_after_sanitization.shape != grad_before_sanitization.shape: + gs.raise_exception("Shape of grad_after_sanitization and grad_before_sanitization do not match.") + return grad_after_sanitization + def _sanitize_2D_io_variables( self, tensor, @@ -1498,6 +1803,15 @@ def _sanitize_2D_io_variables( gs.raise_exception("Expecting 2D input tensor.") return tensor, _inputs_idx, envs_idx + def _sanitize_2D_io_variables_grad( + self, + grad_after_sanitization, + grad_before_sanitization, + ): + if grad_after_sanitization.shape != grad_before_sanitization.shape: + gs.raise_exception("Shape of grad_after_sanitization and grad_before_sanitization do not match.") + return grad_after_sanitization + def _get_qs_idx(self, qs_idx_local=None): return self._get_qs_idx_local(qs_idx_local) + self._q_start @@ -1556,8 +1870,37 @@ def set_base_links_pos( entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) + def set_base_links_pos_grad(self, links_idx, envs_idx, relative, unsafe, pos_grad): + if links_idx is None: + links_idx = self._base_links_idx + pos_grad_, links_idx, envs_idx = self._sanitize_2D_io_variables( + pos_grad.unsqueeze(-2), + links_idx, + self.n_links, + 3, + envs_idx, + idx_name="links_idx", + skip_allocation=True, + unsafe=unsafe, + ) + if self.n_envs == 0: + pos_grad_ = pos_grad_.unsqueeze(0) + if not unsafe and not torch.isin(links_idx, self._base_links_idx).all(): + gs.raise_exception("`links_idx` contains at least one link that is not a base link.") + kernel_set_links_pos_grad( + relative, + pos_grad_, + links_idx, + envs_idx, + links_info=self.links_info, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + def set_links_quat(self, quat, links_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): raise DeprecationError("This method has been removed. Please use 'set_base_links_quat' instead.") @@ -1608,8 +1951,38 @@ def set_base_links_quat( entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) + def set_base_links_quat_grad(self, links_idx, envs_idx, relative, unsafe, quat_grad): + if links_idx is None: + links_idx = self._base_links_idx + quat_grad_, links_idx, envs_idx = self._sanitize_2D_io_variables( + quat_grad.unsqueeze(-2), + links_idx, + self.n_links, + 4, + envs_idx, + idx_name="links_idx", + skip_allocation=True, + unsafe=unsafe, + ) + if self.n_envs == 0: + quat_grad_ = quat_grad_.unsqueeze(0) + if not unsafe and not torch.isin(links_idx, self._base_links_idx).all(): + gs.raise_exception("`links_idx` contains at least one link that is not a base link.") + assert relative == False, "Backward pass for relative quaternion is not supported yet." + kernel_set_links_quat_grad( + relative, + quat_grad_, + links_idx, + envs_idx, + links_info=self.links_info, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + def set_links_mass_shift(self, mass, links_idx=None, envs_idx=None, *, unsafe=False): mass, links_idx, envs_idx = self._sanitize_1D_io_variables( mass, links_idx, self.n_links, envs_idx, idx_name="links_idx", skip_allocation=True, unsafe=unsafe @@ -1686,6 +2059,7 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def set_global_sol_params(self, sol_params, *, unsafe=False): @@ -1874,8 +2248,19 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) + def set_dofs_velocity_grad(self, dofs_idx, envs_idx, unsafe, velocity_grad): + velocity_grad_, dofs_idx, envs_idx = self._sanitize_1D_io_variables( + velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe + ) + if self.n_envs == 0: + velocity_grad_ = velocity_grad_.unsqueeze(0) + kernel_set_dofs_velocity_grad( + velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config + ) + def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): position, dofs_idx, envs_idx = self._sanitize_1D_io_variables( position, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe @@ -1915,6 +2300,7 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forw entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=False, ) def control_dofs_force(self, force, dofs_idx=None, envs_idx=None, *, unsafe=False): @@ -2408,18 +2794,36 @@ def n_links(self): return self._n_links return len(self.links) + @property + def max_n_links_per_entity(self): + if self.is_built: + return self._max_n_links_per_entity + return max(len(entity.links) for entity in self._entities) if self._entities else 0 + @property def n_joints(self): if self.is_built: return self._n_joints return len(self.joints) + @property + def max_n_joints_per_link(self): + if self.is_built: + return self._max_n_joints_per_link + return max(len(link.joints) for link in self.links) if self.links else 0 + @property def n_geoms(self): if self.is_built: return self._n_geoms return len(self.geoms) + @property + def max_n_geoms_per_entity(self): + if self.is_built: + return self._max_n_geoms_per_entity + return max(len(link.joints) for link in self.links) if self.links else 0 + @property def n_cells(self): if self.is_built: @@ -2480,12 +2884,36 @@ def n_qs(self): return self._n_qs return sum([entity.n_qs for entity in self._entities]) + @property + def max_n_qs_per_link(self): + if self.is_built: + return self._max_n_qs_per_link + return max(link.n_qs for link in self.links) if self.links else 0 + @property def n_dofs(self): if self.is_built: return self._n_dofs return sum(entity.n_dofs for entity in self._entities) + @property + def max_n_dofs_per_entity(self): + if self.is_built: + return self._max_n_dofs_per_entity + return max(entity.n_dofs for entity in self._entities) if self._entities else 0 + + @property + def max_n_dofs_per_link(self): + if self.is_built: + return self._max_n_dofs_per_link + return max(link.n_dofs for link in self.links) if self.links else 0 + + @property + def max_n_dofs_per_joint(self): + if self.is_built: + return self._max_n_dofs_per_joint + return max(joint.n_dofs for joint in self.joints) if self.joints else 0 + @property def init_qpos(self): if self._entities: @@ -2514,25 +2942,37 @@ def update_qacc_from_qvel_delta( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): n_dofs = dofs_state.ctrl_mode.shape[0] _B = dofs_state.ctrl_mode.shape[1] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_dofs, _B): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) dofs_state.acc[i_d, i_b] = ( dofs_state.vel[i_d, i_b] - dofs_state.vel_prev[i_d, i_b] ) / rigid_global_info.substep_dt[None] dofs_state.vel[i_d, i_b] = dofs_state.vel_prev[i_d, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.acc[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] - dofs_state.vel_prev[i_d, i_b] - ) / rigid_global_info.substep_dt[None] - dofs_state.vel[i_d, i_b] = dofs_state.vel_prev[i_d, i_b] @ti.kernel @@ -2540,25 +2980,38 @@ def update_qvel( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): _B = dofs_state.vel.shape[1] n_dofs = dofs_state.vel.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_dofs, _B): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) dofs_state.vel_prev[i_d, i_b] = dofs_state.vel[i_d, i_b] dofs_state.vel[i_d, i_b] = ( dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.vel_prev[i_d, i_b] = dofs_state.vel[i_d, i_b] - dofs_state.vel[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] - ) @ti.kernel(fastcache=gs.use_fastcache) @@ -2572,6 +3025,7 @@ def kernel_compute_mass_matrix( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), decompose: ti.template(), + is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=False, @@ -2582,6 +3036,7 @@ def kernel_compute_mass_matrix( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) if decompose: func_factor_mass( @@ -2591,6 +3046,7 @@ def kernel_compute_mass_matrix( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) @@ -3157,6 +3613,7 @@ def kernel_forward_dynamics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): func_forward_dynamics( links_state=links_state, @@ -3170,6 +3627,7 @@ def kernel_forward_dynamics( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, + is_backward=is_backward, ) @@ -3181,6 +3639,7 @@ def kernel_update_acc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): func_update_acc( update_cacc=True, @@ -3190,6 +3649,7 @@ def kernel_update_acc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) @@ -3214,170 +3674,232 @@ def func_compute_mass_matrix( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = links_state.pos.shape[1] - n_links = links_state.pos.shape[0] - n_entities = entities_info.n_links.shape[0] - n_dofs = dofs_state.f_ang.shape[0] - - if ti.static(static_rigid_sim_config.use_hibernation): - # crb initialize - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + # crb initialize + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + links_state.crb_inertial[i_l, i_b] = links_state.cinr_inertial[i_l, i_b] links_state.crb_pos[i_l, i_b] = links_state.cinr_pos[i_l, i_b] links_state.crb_quat[i_l, i_b] = links_state.cinr_quat[i_l, i_b] links_state.crb_mass[i_l, i_b] = links_state.cinr_mass[i_l, i_b] - # crb - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p != -1: - links_state.crb_inertial[i_p, i_b] = ( - links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] + # crb + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + + for i in ( + range(entities_info.n_links[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + if i < entities_info.n_links[i_e]: + i_l = entities_info.link_end[i_e] - 1 - i + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_p = links_info.parent_idx[I_l] + + if i_p != -1: + links_state.crb_inertial[i_p, i_b] += links_state.crb_inertial[i_l, i_b] + links_state.crb_mass[i_p, i_b] += links_state.crb_mass[i_l, i_b] + links_state.crb_pos[i_p, i_b] += links_state.crb_pos[i_l, i_b] + links_state.crb_quat[i_p, i_b] += links_state.crb_quat[i_l, i_b] + + # mass_mat + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_state.pos.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else links_info.dof_start[I_l] + i_d_ + + if i_d < links_info.dof_end[I_l]: + dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( + links_state.crb_pos[i_l, i_b], + links_state.crb_inertial[i_l, i_b], + links_state.crb_mass[i_l, i_b], + dofs_state.cdof_vel[i_d, i_b], + dofs_state.cdof_ang[i_d, i_b], ) - links_state.crb_mass[i_p, i_b] = links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] - links_state.crb_pos[i_p, i_b] = links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] = links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - # mass_mat - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( - links_state.crb_pos[i_l, i_b], - links_state.crb_inertial[i_l, i_b], - links_state.crb_mass[i_l, i_b], - dofs_state.cdof_vel[i_d, i_b], - dofs_state.cdof_ang[i_d, i_b], + for i_d_, j_d_ in ( + ( + # Dynamic inner loop for forward pass + ti.ndrange( + (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), + (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), + ) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static( + ti.ndrange( + static_rigid_sim_config.max_n_dofs_per_entity, + static_rigid_sim_config.max_n_dofs_per_entity, + ) + ) ) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e]: rigid_global_info.mass_mat[i_d, j_d, i_b] = ( dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) ) * rigid_global_info.mass_parent_mask[i_d, j_d] - # FIXME: Updating the lower-part of the mass matrix is irrelevant - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(i_d + 1, entities_info.dof_end[i_e]): - rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] - - # Take into account motor armature - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] - ) - - # Take into account first-order correction terms for implicit integration scheme right away - if ti.static(implicit_damping): + if ti.static(not is_backward): for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] + for j_d in range(i_d + 1, entities_info.dof_end[i_e]): + rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] + else: + for i_d_, j_d_ in ti.static( + ti.ndrange( + static_rigid_sim_config.max_n_dofs_per_entity, + static_rigid_sim_config.max_n_dofs_per_entity, ) - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - # qM += d qfrc_actuator / d qvel - rigid_global_info.mass_mat[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] - ) - else: - # crb initialize - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - links_state.crb_inertial[i_l, i_b] = links_state.cinr_inertial[i_l, i_b] - links_state.crb_pos[i_l, i_b] = links_state.cinr_pos[i_l, i_b] - links_state.crb_quat[i_l, i_b] = links_state.cinr_quat[i_l, i_b] - links_state.crb_mass[i_l, i_b] = links_state.cinr_mass[i_l, i_b] - - # crb - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p != -1: - links_state.crb_inertial[i_p, i_b] = ( - links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] - ) - links_state.crb_mass[i_p, i_b] = links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] - - links_state.crb_pos[i_p, i_b] = links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] = links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] - - # mass_mat - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( - links_state.crb_pos[i_l, i_b], - links_state.crb_inertial[i_l, i_b], - links_state.crb_mass[i_l, i_b], - dofs_state.cdof_vel[i_d, i_b], - dofs_state.cdof_ang[i_d, i_b], - ) + ): + i_d = entities_info.dof_start[i_e] + i_d_ + j_d = entities_info.dof_start[i_e] + j_d_ - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_d, j_d in ti.ndrange( - (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), - (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), - ): - rigid_global_info.mass_mat[i_d, j_d, i_b] = ( - dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) - + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) - ) * rigid_global_info.mass_parent_mask[i_d, j_d] + if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e] and j_d > i_d: + rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] - # FIXME: Updating the lower-part of the mass matrix is irrelevant - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(i_d + 1, entities_info.dof_end[i_e]): - rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] + # Take into account motor armature + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(dofs_state.f_ang.shape[0], links_state.pos.shape[1]): + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.armature[I_d] - # Take into account motor armature + # Take into account first-order correction terms for implicit integration scheme right away + if ti.static(implicit_damping): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): + for i_d, i_b in ti.ndrange(dofs_state.f_ang.shape[0], links_state.pos.shape[1]): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] - ) - - # Take into account first-order correction terms for implicit integration scheme right away - if ti.static(implicit_damping): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - # qM += d qfrc_actuator / d qvel - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] + if ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + # qM += d qfrc_actuator / d qvel + rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] @ti.func @@ -3388,115 +3910,281 @@ def func_factor_mass( dofs_info: array_class.DofsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): """ Compute Cholesky decomposition (L^T @ D @ L) of mass matrix. """ - _B = dofs_state.ctrl_mode.shape[1] - n_entities = entities_info.n_links.shape[0] + if ti.static(not is_backward): + _B = dofs_state.ctrl_mode.shape[1] + n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if rigid_global_info.mass_mat_mask[i_e, i_b]: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] + if rigid_global_info.mass_mat_mask[i_e, i_b]: + entity_dof_start = entities_info.dof_start[i_e] + entity_dof_end = entities_info.dof_end[i_e] + n_dofs = entities_info.n_dofs[i_e] - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ - if ti.static(implicit_damping): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + if i_d < entity_dof_end: + for j_d_ in ( + range(entity_dof_start, i_d + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): + j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ + + if j_d < i_d + 1: + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[ + i_d, j_d, i_b + ] + + if ti.static(implicit_damping): + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] ) - - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] - - for j_d_ in range(i_d - entity_dof_start): - j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] - for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] + if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): + if (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) or ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( + dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + ) + + for i_d_ in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d_ < n_dofs: + i_d = entity_dof_end - i_d_ - 1 + rigid_global_info.mass_mat_D_inv[i_d, i_b] = ( + 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + for j_d_ in ( + range(i_d - entity_dof_start) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if j_d_ < i_d - entity_dof_start: + j_d = i_d - j_d_ - 1 + a = ( + rigid_global_info.mass_mat_L[i_d, j_d, i_b] + * rigid_global_info.mass_mat_D_inv[i_d, i_b] + ) + + for k_d_ in ( + range(entity_dof_start, j_d + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + k_d = ( + k_d_ + if ti.static(not is_backward) + else entities_info.dof_start[i_e] + k_d_ + ) + if k_d < j_d + 1: + rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( + a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] + ) + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a + + # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. + rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e, i_b in ti.ndrange(n_entities, _B): + # Cholesky decomposition that has safe access pattern and robust handling of divide by zero for AD. Even though + # it is logically equivalent to the above block, it shows slightly numerical difference in the result, and thus + # it fails for a unit test ("test_urdf_rope"), while passing all the others. TODO: Investigate if we can fix this + # and only use this block. + + # Assume this is the outermost loop + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]): if rigid_global_info.mass_mat_mask[i_e, i_b]: entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] n_dofs = entities_info.n_dofs[i_e] - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] + for i_d0 in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d0 < n_dofs: + i_d = entity_dof_start + i_d0 + i_pr = (entity_dof_start + entity_dof_end - 1) - i_d + for j_d_ in ( + range(entity_dof_start, i_d + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + j_pr = (entity_dof_start + entity_dof_end - 1) - j_d + if j_d < i_d + 1: + rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] = rigid_global_info.mass_mat[ + i_d, j_d, i_b + ] + rigid_global_info.mass_mat_L_bw[0, j_pr, i_pr, i_b] = rigid_global_info.mass_mat[ + i_d, j_d, i_b + ] + + if ti.static(implicit_damping): + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b] += ( + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] + ) + if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): + if ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b] += ( + dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] + ) - if ti.static(implicit_damping): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + # Cholesky-Banachiewicz algorithm (in the perturbed indices), access pattern is safe for autodiff + # https://en.wikipedia.org/wiki/Cholesky_decomposition + for p_i0 in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + for p_j0 in ( + range(p_i0 + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if p_i0 < n_dofs and p_j0 < n_dofs and p_j0 <= p_i0: + # j_pr <= i_pr + i_pr = entity_dof_start + p_i0 + j_pr = entity_dof_start + p_j0 + + sum = gs.ti_float(0.0) + for p_k0 in ( + range(p_j0) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * rigid_global_info.substep_dt[None] - ) - - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] + # k_pr < j_pr + if p_k0 < p_j0: + k_pr = entity_dof_start + p_k0 + sum += ( + rigid_global_info.mass_mat_L_bw[1, i_pr, k_pr, i_b] + * rigid_global_info.mass_mat_L_bw[1, j_pr, k_pr, i_b] + ) - for j_d_ in range(i_d - entity_dof_start): - j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] - for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] - ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a + a = rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] - sum + b = ti.math.clamp(rigid_global_info.mass_mat_L_bw[1, j_pr, j_pr, i_b], gs.EPS, ti.math.inf) + if p_i0 == p_j0: + rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = ti.sqrt( + ti.math.clamp(a, gs.EPS, ti.math.inf) + ) + else: + rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = a / b - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + for i_d0 in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + for i_d1 in ( + range(i_d0 + 1) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d0 < n_dofs and i_d1 < n_dofs and i_d1 <= i_d0: + i_d = entity_dof_start + i_d0 + j_d = entity_dof_start + i_d1 + i_pr = (entity_dof_start + entity_dof_end - 1) - i_d + j_pr = (entity_dof_start + entity_dof_end - 1) - j_d + + a = rigid_global_info.mass_mat_L_bw[1, i_pr, i_pr, i_b] + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat_L_bw[ + 1, j_pr, i_pr, i_b + ] / ti.math.clamp(a, gs.EPS, ti.math.inf) + + if i_d == j_d: + rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / ( + ti.math.clamp(a**2, gs.EPS, ti.math.inf) + ) @ti.func def func_solve_mass_batched( vec: array_class.V_ANNOTATION, out: array_class.V_ANNOTATION, + out_bw: array_class.V_ANNOTATION, i_b: ti.int32, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): + # This loop is considered an inner loop + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(entities_info.n_links.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(entities_info.n_links.shape[0])) + ) + ): + n_entities = entities_info.n_links.shape[0] - n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] + if i_0 < ( + rigid_global_info.n_awake_entities[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else n_entities + ): + i_e = ( + rigid_global_info.awake_entities[i_0, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) if rigid_global_info.mass_mat_mask[i_e, i_b]: entity_dof_start = entities_info.dof_start[i_e] @@ -3504,63 +4192,93 @@ def func_solve_mass_batched( n_dofs = entities_info.n_dofs[i_e] # Step 1: Solve w st. L^T @ w = y - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - out[i_d, i_b] = vec[i_d, i_b] - for j_d in range(i_d + 1, entity_dof_end): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] + for i_d_ in ( + range(n_dofs) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + if i_d_ < n_dofs: + i_d = entity_dof_end - i_d_ - 1 + if ti.static(is_backward): + out_bw[0, i_d, i_b] = vec[i_d, i_b] + else: + out[i_d, i_b] = vec[i_d, i_b] + + for j_d_ in ( + range(i_d + 1, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + if j_d >= i_d + 1 and j_d < entity_dof_end: + # Since we read out[j_d, i_b], and j_d > i_d, which means that out[j_d, i_b] is already + # finalized at this point, we don't need to care about AD mutation rule. + if ti.static(is_backward): + out_bw[0, i_d, i_b] += -( + rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out_bw[0, j_d, i_b] + ) + else: + out[i_d, i_b] += -(rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b]) # Step 2: z = D^{-1} w - for i_d in range(entity_dof_start, entity_dof_end): - out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + entities_info.dof_start[i_e]) + if i_d < entity_dof_end: + if ti.static(is_backward): + out_bw[1, i_d, i_b] = out_bw[0, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] + else: + out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] # Step 3: Solve x st. L @ x = z - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] - else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e in range(n_entities): - if rigid_global_info.mass_mat_mask[i_e, i_b]: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] - - # Step 1: Solve w st. L^T @ w = y - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - out[i_d, i_b] = vec[i_d, i_b] - for j_d in range(i_d + 1, entity_dof_end): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] - - # Step 2: z = D^{-1} w - for i_d in range(entity_dof_start, entity_dof_end): - out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + entities_info.dof_start[i_e]) + if i_d < entity_dof_end: + curr_out = out[i_d, i_b] + if ti.static(is_backward): + curr_out = out_bw[1, i_d, i_b] + + for j_d_ in ( + range(entity_dof_start, i_d) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + if j_d < i_d: + curr_out += -(rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]) - # Step 3: Solve x st. L @ x = z - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] + out[i_d, i_b] = curr_out @ti.func def func_solve_mass( vec: array_class.V_ANNOTATION, out: array_class.V_ANNOTATION, + out_bw: array_class.V_ANNOTATION, # Should not be None if backward entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = out.shape[1] + # This loop must be the outermost loop to be differentiable ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): + for i_b in range(out.shape[1]): func_solve_mass_batched( vec, out, + out_bw, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) @@ -3579,6 +4297,7 @@ def func_forward_dynamics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), @@ -3589,6 +4308,7 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_factor_mass( implicit_damping=False, @@ -3597,6 +4317,7 @@ def func_forward_dynamics( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_torque_and_passive_force( entities_state=entities_state, @@ -3610,6 +4331,7 @@ def func_forward_dynamics( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, + is_backward=is_backward, ) func_update_acc( update_cacc=False, @@ -3619,6 +4341,7 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_update_force( links_state=links_state, @@ -3626,6 +4349,7 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) # self._func_actuation() func_bias_force( @@ -3634,95 +4358,217 @@ def func_forward_dynamics( links_info=links_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_compute_qacc( dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) -@ti.kernel(fastcache=gs.use_fastcache) -def kernel_clear_external_force( - links_state: array_class.LinksState, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), -): - func_clear_external_force( - links_state=links_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - - -@ti.func -def func_update_cartesian_space( - i_b, +@ti.kernel +def kernel_forward_dynamics_without_qacc( links_state: array_class.LinksState, links_info: array_class.LinksInfo, - joints_state: array_class.JointsState, - joints_info: array_class.JointsInfo, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, - geoms_info: array_class.GeomsInfo, - geoms_state: array_class.GeomsState, + joints_info: array_class.JointsInfo, + entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, + geoms_state: array_class.GeomsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - force_update_fixed_geoms: ti.template(), + contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): - func_forward_kinematics( - i_b, + func_compute_mass_matrix( + implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), links_state=links_state, links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, dofs_state=dofs_state, dofs_info=dofs_info, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) - func_COM_links( - i_b, - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, + func_factor_mass( + implicit_damping=False, + entities_info=entities_info, dofs_state=dofs_state, dofs_info=dofs_info, - entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) - func_forward_velocity( - i_b, + func_torque_and_passive_force( + entities_state=entities_state, entities_info=entities_info, - links_info=links_info, + dofs_state=dofs_state, + dofs_info=dofs_info, links_state=links_state, + links_info=links_info, joints_info=joints_info, + geoms_state=geoms_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + contact_island_state=contact_island_state, + is_backward=is_backward, + ) + func_update_acc( + update_cacc=False, dofs_state=dofs_state, + links_info=links_info, + links_state=links_state, + entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) - - func_update_geoms( - i_b=i_b, + func_update_force( + links_state=links_state, + links_info=links_info, entities_info=entities_info, - geoms_info=geoms_info, - geoms_state=geoms_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + # self._func_actuation() + func_bias_force( + dofs_state=dofs_state, links_state=links_state, + links_info=links_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, ) @ti.kernel(fastcache=gs.use_fastcache) -def kernel_step_1( +def kernel_clear_external_force( links_state: array_class.LinksState, - links_info: array_class.LinksInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + func_clear_external_force( + links_state=links_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_update_cartesian_space( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(links_state.pos.shape[1]): + func_update_cartesian_space( + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + geoms_info=geoms_info, + geoms_state=geoms_state, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, + ) + + +@ti.func +def func_update_cartesian_space( + i_b, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_info: array_class.GeomsInfo, + geoms_state: array_class.GeomsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), +): + func_forward_kinematics( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + func_COM_links( + i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + func_forward_velocity( + i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + + func_update_geoms( + i_b=i_b, + entities_info=entities_info, + geoms_info=geoms_info, + geoms_state=geoms_state, + links_state=links_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + force_update_fixed_geoms=force_update_fixed_geoms, + is_backward=is_backward, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_step_1( + f: ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, joints_state: array_class.JointsState, joints_info: array_class.JointsInfo, dofs_state: array_class.DofsState, @@ -3732,9 +4578,21 @@ def kernel_step_1( entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): + if ti.static(static_rigid_sim_config.requires_grad): + if f == 0: + func_save_adjoint_cache( + f=f, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + rigid_adjoint_cache=rigid_adjoint_cache, + static_rigid_sim_config=static_rigid_sim_config, + ) + if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): _B = links_state.pos.shape[1] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -3753,6 +4611,7 @@ def kernel_step_1( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, + is_backward=is_backward, ) func_forward_dynamics( @@ -3767,6 +4626,7 @@ def kernel_step_1( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, + is_backward=is_backward, ) @@ -3777,6 +4637,7 @@ def func_implicit_damping( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): EPS = rigid_global_info.EPS[None] @@ -3796,16 +4657,22 @@ def func_implicit_damping( for i_e, i_b in ti.ndrange(n_entities, _B): entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] - for i_d in range(entity_dof_start, entity_dof_end): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - if dofs_info.damping[I_d] > EPS: - rigid_global_info.mass_mat_mask[i_e, i_b] = True - if ti.static(static_rigid_sim_config.integrator != gs.integrator.Euler): - if ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ) and dofs_info.kv[I_d] > EPS: + for i_d_ in ( + range(entity_dof_start, entity_dof_end) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): + i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + if i_d < entity_dof_end: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + if dofs_info.damping[I_d] > EPS: rigid_global_info.mass_mat_mask[i_e, i_b] = True + if ti.static(static_rigid_sim_config.integrator != gs.integrator.Euler): + if ( + dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + ) and dofs_info.kv[I_d] > EPS: + rigid_global_info.mass_mat_mask[i_e, i_b] = True func_factor_mass( implicit_damping=True, @@ -3814,13 +4681,16 @@ def func_implicit_damping( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_solve_mass( vec=dofs_state.force, out=dofs_state.acc, + out_bw=dofs_state.acc_bw, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) # Disable pre-computed factorization mask right away @@ -3834,6 +4704,7 @@ def func_implicit_damping( @ti.kernel(fastcache=gs.use_fastcache) def kernel_step_2( + f: ti.int32, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, links_info: array_class.LinksInfo, @@ -3846,8 +4717,10 @@ def kernel_step_2( geoms_state: array_class.GeomsState, collider_state: array_class.ColliderState, rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): # Position, Velocity and Acceleration data must be consistent when computing links acceleration, otherwise it # would not corresponds to anyting physical. There is no other way than doing this right before integration, @@ -3862,6 +4735,7 @@ def kernel_step_2( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) if ti.static(static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast): @@ -3871,6 +4745,7 @@ def kernel_step_2( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) func_integrate( @@ -3879,6 +4754,7 @@ def kernel_step_2( joints_info=joints_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) if ti.static(static_rigid_sim_config.use_hibernation): @@ -3901,24 +4777,41 @@ def kernel_step_2( static_rigid_sim_config=static_rigid_sim_config, ) - if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): - _B = links_state.pos.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - func_update_cartesian_space( - i_b=i_b, - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, + if ti.static(not is_backward): + func_copy_next_to_curr( + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): + _B = links_state.pos.shape[1] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_b in range(_B): + func_update_cartesian_space( + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + geoms_info=geoms_info, + geoms_state=geoms_state, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=is_backward, + ) + + if ti.static(static_rigid_sim_config.requires_grad): + func_save_adjoint_cache( + f=f + 1, dofs_state=dofs_state, - dofs_info=dofs_info, - geoms_info=geoms_info, - geoms_state=geoms_state, - entities_info=entities_info, rigid_global_info=rigid_global_info, + rigid_adjoint_cache=rigid_adjoint_cache, static_rigid_sim_config=static_rigid_sim_config, - force_update_fixed_geoms=False, ) @@ -3936,6 +4829,7 @@ def kernel_forward_kinematics_links_geoms( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -3954,6 +4848,7 @@ def kernel_forward_kinematics_links_geoms( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=True, + is_backward=is_backward, ) @@ -3969,26 +4864,65 @@ def func_COM_links( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) - links_state.root_COM[i_l, i_b].fill(0.0) + links_state.root_COM_bw[i_l, i_b].fill(0.0) links_state.mass_sum[i_l, i_b] = 0.0 - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] ( - links_state.i_pos[i_l, i_b], + links_state.i_pos_bw[i_l, i_b], links_state.i_quat[i_l, i_b], ) = gu.ti_transform_pos_quat_by_trans_quat( links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], @@ -3999,32 +4933,95 @@ def func_COM_links( i_r = links_info.root_idx[I_l] links_state.mass_sum[i_r, i_b] += mass - links_state.root_COM[i_r, i_b] += mass * links_state.i_pos[i_l, i_b] + links_state.root_COM_bw[i_r, i_b] += mass * links_state.i_pos_bw[i_l, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_r = links_info.root_idx[I_l] if i_l == i_r: - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b] + links_state.root_COM[i_l, i_b] = links_state.root_COM_bw[i_l, i_b] / links_state.mass_sum[i_l, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_r = links_info.root_idx[I_l] links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_r = links_info.root_idx[I_l] - links_state.i_pos[i_l, i_b] = links_state.i_pos[i_l, i_b] - links_state.root_COM[i_l, i_b] + links_state.i_pos[i_l, i_b] = links_state.i_pos_bw[i_l, i_b] - links_state.root_COM[i_l, i_b] i_inertial = links_info.inertial_i[I_l] i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] @@ -4034,295 +5031,221 @@ def func_COM_links( links_state.cinr_quat[i_l, i_b], links_state.cinr_mass[i_l, i_b], ) = gu.ti_transform_inertia_by_trans_quat( - i_inertial, i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b], EPS + i_inertial, + i_mass, + links_state.i_pos[i_l, i_b], + links_state.i_quat[i_l, i_b], + rigid_global_info.EPS[None], ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - i_p = links_info.parent_idx[I_l] - - _i_j = links_info.joint_start[I_l] - _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j - joint_type = joints_info.type[_I_j] - - p_pos = ti.Vector.zero(gs.ti_float, 3) - p_quat = gu.ti_identity_quat() - if i_p != -1: - p_pos = links_state.pos[i_p, i_b] - p_quat = links_state.quat[i_p, i_b] + if links_info.n_dofs[I_l] > 0: + i_p = links_info.parent_idx[I_l] - if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): - links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] - links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] - else: - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + _i_j = links_info.joint_start[I_l] + _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j + joint_type = joints_info.type[_I_j] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + p_pos = ti.Vector.zero(gs.ti_float, 3) + p_quat = gu.ti_identity_quat() + if i_p != -1: + p_pos = links_state.pos[i_p, i_b] + p_quat = links_state.quat[i_p, i_b] + if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): + links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] + links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] + else: ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - joints_info.pos[I_j], - gu.ti_identity_quat(), - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) + links_state.j_pos_bw[i_l, 0, i_b], + links_state.j_quat_bw[i_l, 0, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + for i_j_ in ( + range(n_joints) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] + + curr_i_j = 0 if ti.static(not is_backward) else i_j_ + next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 + + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + + ( + links_state.j_pos_bw[i_l, next_i_j, i_b], + links_state.j_quat_bw[i_l, next_i_j, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + joints_info.pos[I_j], + gu.ti_identity_quat(), + links_state.j_pos_bw[i_l, curr_i_j, i_b], + links_state.j_quat_bw[i_l, curr_i_j, i_b], + ) - # cdof_fn - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + i_j_ = 0 if ti.static(not is_backward) else n_joints + links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, i_j_, i_b] + links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, i_j_, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(links_info.root_idx.shape[0]) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(links_info.root_idx.shape[0])) + ) + ): + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): + i_l = ( + rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + if links_info.n_dofs[I_l] > 0: + for i_j_ in ( + range(links_info.joint_start[I_l], links_info.joint_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ if ti.static(not is_backward) else (i_j_ + links_info.joint_start[I_l]) + + if i_j < links_info.joint_end[I_l]: + offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + + dof_start = joints_info.dof_start[I_j] + + EPS = rigid_global_info.EPS[None] + if joint_type == gs.JOINT_TYPE.REVOLUTE: + dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.FREE: + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[i + dof_start, i_b][i] = 1.0 + + xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[i + dof_start + 3, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos) + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + if i_d < joints_info.dof_end[I_j]: + dofs_state.cdofvel_ang[i_d, i_b] = ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + dofs_state.cdofvel_vel[i_d, i_b] = ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) - if joint_type == gs.JOINT_TYPE.FREE: - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.cdof_vel[i_d, i_b] = dofs_info.motion_vel[I_d] - dofs_state.cdof_ang[i_d, i_b] = gu.ti_transform_by_quat( - dofs_info.motion_ang[I_d], links_state.j_quat[i_l, i_b] - ) - offset_pos = links_state.root_COM[i_l, i_b] - links_state.j_pos[i_l, i_b] - ( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) = gu.ti_transform_motion_by_trans_quat( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - offset_pos, - gu.ti_identity_quat(), - ) +@ti.func +def func_forward_kinematics( + i_b, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + is_backward: ti.template(), +): + for i_e_ in ( + ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(entities_info.n_links.shape[0]) + ) + if ti.static(not is_backward) + else ( + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(entities_info.n_links.shape[0])) + ) + ): + if i_e_ < ( + rigid_global_info.n_awake_entities[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else entities_info.n_links.shape[0] + ): + i_e = ( + rigid_global_info.awake_entities[i_e_, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_e_ + ) - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - else: - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - motion_vel = dofs_info.motion_vel[I_d] - motion_ang = dofs_info.motion_ang[I_d] - - dofs_state.cdof_ang[i_d, i_b] = gu.ti_transform_by_quat(motion_ang, links_state.j_quat[i_l, i_b]) - dofs_state.cdof_vel[i_d, i_b] = gu.ti_transform_by_quat(motion_vel, links_state.j_quat[i_l, i_b]) - - offset_pos = links_state.root_COM[i_l, i_b] - links_state.j_pos[i_l, i_b] - ( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) = gu.ti_transform_motion_by_trans_quat( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - offset_pos, - gu.ti_identity_quat(), - ) - - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - links_state.root_COM[i_l, i_b].fill(0.0) - links_state.mass_sum[i_l, i_b] = 0.0 - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.i_pos[i_l, i_b], - links_state.i_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], - links_info.inertial_quat[I_l], - links_state.pos[i_l, i_b], - links_state.quat[i_l, i_b], - ) - - i_r = links_info.root_idx[I_l] - links_state.mass_sum[i_r, i_b] += mass - links_state.root_COM[i_r, i_b] += mass * links_state.i_pos[i_l, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - if i_l == i_r: - if links_state.mass_sum[i_l, i_b] > 0.0: - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - links_state.i_pos[i_l, i_b] = links_state.i_pos[i_l, i_b] - links_state.root_COM[i_l, i_b] - - i_inertial = links_info.inertial_i[I_l] - i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_pos[i_l, i_b], - links_state.cinr_quat[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - ) = gu.ti_transform_inertia_by_trans_quat( - i_inertial, i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b], EPS - ) - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_p = links_info.parent_idx[I_l] - - _i_j = links_info.joint_start[I_l] - _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j - joint_type = joints_info.type[_I_j] - - p_pos = ti.Vector.zero(gs.ti_float, 3) - p_quat = gu.ti_identity_quat() - if i_p != -1: - p_pos = links_state.pos[i_p, i_b] - p_quat = links_state.quat[i_p, i_b] - - if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): - links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] - links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] - else: - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) - - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - joints_info.pos[I_j], - gu.ti_identity_quat(), - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) - - # cdof_fn - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - dof_start = joints_info.dof_start[I_j] - - if joint_type == gs.JOINT_TYPE.REVOLUTE: - dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] - dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) - elif joint_type == gs.JOINT_TYPE.PRISMATIC: - dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() - for j in ti.static(range(3)): - dofs_state.cdof_ang[j + dof_start, i_b] = xmat_T[j, :] - dofs_state.cdof_vel[j + dof_start, i_b] = xmat_T[j, :].cross(offset_pos) - elif joint_type == gs.JOINT_TYPE.FREE: - for j in ti.static(range(3)): - dofs_state.cdof_ang[j + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[j + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[j + dof_start, i_b][j] = 1.0 - - xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b], EPS).transpose() - for j in ti.static(range(3)): - dofs_state.cdof_ang[j + dof_start + 3, i_b] = xmat_T[j, :] - dofs_state.cdof_vel[j + dof_start + 3, i_b] = xmat_T[j, :].cross(offset_pos) - - for i_d in range(dof_start, joints_info.dof_end[I_j]): - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - - -@ti.func -def func_forward_kinematics( - i_b, - links_state: array_class.LinksState, - links_info: array_class.LinksInfo, - joints_state: array_class.JointsState, - joints_info: array_class.JointsInfo, - dofs_state: array_class.DofsState, - dofs_info: array_class.DofsInfo, - entities_info: array_class.EntitiesInfo, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), -): - n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - func_forward_kinematics_entity( - i_e, - i_b, - links_state, - links_info, - joints_state, - joints_info, - dofs_state, - dofs_info, - entities_info, - rigid_global_info, - static_rigid_sim_config, - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e in range(n_entities): - func_forward_kinematics_entity( - i_e, - i_b, - links_state, - links_info, - joints_state, - joints_info, - dofs_state, - dofs_info, - entities_info, - rigid_global_info, - static_rigid_sim_config, - ) + func_forward_kinematics_entity( + i_e, + i_b, + links_state, + links_info, + joints_state, + joints_info, + dofs_state, + dofs_info, + entities_info, + rigid_global_info, + static_rigid_sim_config, + is_backward, + ) @ti.func @@ -4335,37 +5258,39 @@ def func_forward_velocity( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - func_forward_velocity_entity( - i_e=i_e, - i_b=i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e in range(n_entities): - func_forward_velocity_entity( - i_e=i_e, - i_b=i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) + for i_e_ in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_entities) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(entities_info.n_links.shape[0])) + ) + ): + i_e = ( + rigid_global_info.awake_entities[i_e_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_e_ + ) + func_forward_velocity_entity( + i_e=i_e, + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) @ti.kernel(fastcache=gs.use_fastcache) @@ -4381,6 +5306,7 @@ def kernel_forward_kinematics_entity( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -4397,6 +5323,7 @@ def kernel_forward_kinematics_entity( entities_info, rigid_global_info, static_rigid_sim_config, + is_backward, ) @@ -4413,106 +5340,143 @@ def func_forward_kinematics_entity( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] + # Becomes static loop in backward pass, because we assume this loop is an inner loop + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + EPS = rigid_global_info.EPS[None] + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - pos = links_info.pos[I_l] - quat = links_info.quat[I_l] - if links_info.parent_idx[I_l] != -1: - parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] - parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] - pos = parent_pos + gu.ti_transform_by_quat(pos, parent_quat) - quat = gu.ti_transform_quat_by_quat(quat, parent_quat) + links_state.pos_bw[i_l, 0, i_b] = links_info.pos[I_l] + links_state.quat_bw[i_l, 0, i_b] = links_info.quat[I_l] + if links_info.parent_idx[I_l] != -1: + parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] + parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] + links_state.pos_bw[i_l, 0, i_b] = parent_pos + gu.ti_transform_by_quat(links_info.pos[I_l], parent_quat) + links_state.quat_bw[i_l, 0, i_b] = gu.ti_transform_quat_by_quat(links_info.quat[I_l], parent_quat) - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - q_start = joints_info.q_start[I_j] - dof_start = joints_info.dof_start[I_j] - I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] - # compute axis and anchor - if joint_type == gs.JOINT_TYPE.FREE: - joints_state.xanchor[i_j, i_b] = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ] - ) - joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - else: - axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float) - if joint_type == gs.JOINT_TYPE.REVOLUTE: - axis = dofs_info.motion_ang[I_d] - elif joint_type == gs.JOINT_TYPE.PRISMATIC: - axis = dofs_info.motion_vel[I_d] + for i_j_ in ( + range(n_joints) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] - joints_state.xanchor[i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos - joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat(axis, quat) + curr_i_j = 0 if ti.static(not is_backward) else i_j_ + next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 - if joint_type == gs.JOINT_TYPE.FREE: - pos = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ], - dt=gs.ti_float, - ) - quat = ti.Vector( - [ - rigid_global_info.qpos[q_start + 3, i_b], - rigid_global_info.qpos[q_start + 4, i_b], - rigid_global_info.qpos[q_start + 5, i_b], - rigid_global_info.qpos[q_start + 6, i_b], - ], - dt=gs.ti_float, - ) - xyz = gu.ti_quat_to_xyz(quat, EPS) - for j in ti.static(range(3)): - dofs_state.pos[dof_start + j, i_b] = pos[j] - dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j] - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - qloc = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - rigid_global_info.qpos[q_start + 3, i_b], - ], - dt=gs.ti_float, - ) - xyz = gu.ti_quat_to_xyz(qloc, EPS) - for j in ti.static(range(3)): - dofs_state.pos[dof_start + j, i_b] = xyz[j] - quat = gu.ti_transform_quat_by_quat(qloc, quat) - pos = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) - elif joint_type == gs.JOINT_TYPE.REVOLUTE: - axis = dofs_info.motion_ang[I_d] - dofs_state.pos[dof_start, i_b] = ( - rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] - ) - qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS) - quat = gu.ti_transform_quat_by_quat(qloc, quat) - pos = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) - else: # joint_type == gs.JOINT_TYPE.PRISMATIC: - dofs_state.pos[dof_start, i_b] = ( - rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] - ) - pos = pos + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + I_d = [dof_start, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + + # compute axis and anchor + if joint_type == gs.JOINT_TYPE.FREE: + joints_state.xanchor[i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ] + ) + joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + else: + axis = ti.Vector([0.0, 0.0, 1.0], dt=gs.ti_float) + if joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + axis = dofs_info.motion_vel[I_d] + + joints_state.xanchor[i_j, i_b] = ( + gu.ti_transform_by_quat(joints_info.pos[I_j], links_state.quat_bw[i_l, curr_i_j, i_b]) + + links_state.pos_bw[i_l, curr_i_j, i_b] + ) + joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat( + axis, links_state.quat_bw[i_l, curr_i_j, i_b] + ) + + if joint_type == gs.JOINT_TYPE.FREE: + links_state.pos_bw[i_l, next_i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + ], + dt=gs.ti_float, + ) + links_state.quat_bw[i_l, next_i_j, i_b] = ti.Vector( + [ + rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[q_start + 4, i_b], + rigid_global_info.qpos[q_start + 5, i_b], + rigid_global_info.qpos[q_start + 6, i_b], + ], + dt=gs.ti_float, + ) + xyz = gu.ti_quat_to_xyz(links_state.quat_bw[i_l, next_i_j, i_b], EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = links_state.pos_bw[i_l, next_i_j, i_b][j] + dofs_state.pos[dof_start + 3 + j, i_b] = xyz[j] + elif joint_type == gs.JOINT_TYPE.FIXED: + pass + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + qloc = ti.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[q_start + 3, i_b], + ], + dt=gs.ti_float, + ) + xyz = gu.ti_quat_to_xyz(qloc, EPS) + for j in ti.static(range(3)): + dofs_state.pos[dof_start + j, i_b] = xyz[j] + links_state.quat_bw[i_l, next_i_j, i_b] = gu.ti_transform_quat_by_quat( + qloc, links_state.quat_bw[i_l, curr_i_j, i_b] + ) + links_state.pos_bw[i_l, next_i_j, i_b] = joints_state.xanchor[ + i_j, i_b + ] - gu.ti_transform_by_quat(joints_info.pos[I_j], links_state.quat_bw[i_l, next_i_j, i_b]) + elif joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b], EPS) + links_state.quat_bw[i_l, next_i_j, i_b] = gu.ti_transform_quat_by_quat( + qloc, links_state.quat_bw[i_l, curr_i_j, i_b] + ) + links_state.pos_bw[i_l, next_i_j, i_b] = joints_state.xanchor[ + i_j, i_b + ] - gu.ti_transform_by_quat(joints_info.pos[I_j], links_state.quat_bw[i_l, next_i_j, i_b]) + else: # joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.pos[dof_start, i_b] = ( + rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + ) + links_state.pos_bw[i_l, next_i_j, i_b] = ( + links_state.pos_bw[i_l, curr_i_j, i_b] + + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + ) - # Skip link pose update for fixed root links to let users manually overwrite them - if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): - links_state.pos[i_l, i_b] = pos - links_state.quat[i_l, i_b] = quat + # Skip link pose update for fixed root links to let users manually overwrite them + i_j_ = 0 if ti.static(not is_backward) else n_joints + if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): + links_state.pos[i_l, i_b] = links_state.pos_bw[i_l, i_j_, i_b] + links_state.quat[i_l, i_b] = links_state.quat_bw[i_l, i_j_, i_b] @ti.func @@ -4526,71 +5490,113 @@ def func_forward_velocity_entity( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - cvel_vel = ti.Vector.zero(gs.ti_float, 3) - cvel_ang = ti.Vector.zero(gs.ti_float, 3) - if links_info.parent_idx[I_l] != -1: - cvel_vel = links_state.cd_vel[links_info.parent_idx[I_l], i_b] - cvel_ang = links_state.cd_ang[links_info.parent_idx[I_l], i_b] + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - q_start = joints_info.q_start[I_j] - dof_start = joints_info.dof_start[I_j] + links_state.cd_vel_bw[i_l, 0, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cd_ang_bw[i_l, 0, i_b] = ti.Vector.zero(gs.ti_float, 3) - if joint_type == gs.JOINT_TYPE.FREE: - for i_3 in ti.static(range(3)): - cvel_vel = ( - cvel_vel + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] - ) - cvel_ang = ( - cvel_ang + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] - ) + if links_info.parent_idx[I_l] != -1: + links_state.cd_vel_bw[i_l, 0, i_b] = links_state.cd_vel[links_info.parent_idx[I_l], i_b] + links_state.cd_ang_bw[i_l, 0, i_b] = links_state.cd_ang[links_info.parent_idx[I_l], i_b] - for i_3 in ti.static(range(3)): - ( - dofs_state.cdofd_ang[dof_start + i_3, i_b], - dofs_state.cdofd_vel[dof_start + i_3, i_b], - ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) + for i_j_ in ( + range(n_joints) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) + ): + i_j = i_j_ + links_info.joint_start[I_l] - ( - dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], - dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], - ) = gu.motion_cross_motion( - cvel_ang, - cvel_vel, - dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], - dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], - ) + if i_j < links_info.joint_end[I_l]: + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] - for i_3 in ti.static(range(3)): - cvel_vel = ( - cvel_vel - + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] - ) - cvel_ang = ( - cvel_ang - + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] - ) + curr_i_j = 0 if ti.static(not is_backward) else i_j_ + next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 - else: - for i_d in range(dof_start, joints_info.dof_end[I_j]): - dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( - cvel_ang, - cvel_vel, - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) - for i_d in range(dof_start, joints_info.dof_end[I_j]): - cvel_vel = cvel_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - cvel_ang = cvel_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + if joint_type == gs.JOINT_TYPE.FREE: + for i_3 in ti.static(range(3)): + links_state.cd_vel_bw[i_l, curr_i_j, i_b] += ( + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) + links_state.cd_ang_bw[i_l, curr_i_j, i_b] += ( + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) + + for i_3 in ti.static(range(3)): + ( + dofs_state.cdofd_ang[dof_start + i_3, i_b], + dofs_state.cdofd_vel[dof_start + i_3, i_b], + ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) + + ( + dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], + ) = gu.motion_cross_motion( + links_state.cd_ang_bw[i_l, curr_i_j, i_b], + links_state.cd_vel_bw[i_l, curr_i_j, i_b], + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], + ) - links_state.cd_vel[i_l, i_b] = cvel_vel - links_state.cd_ang[i_l, i_b] = cvel_ang + links_state.cd_vel_bw[i_l, next_i_j, i_b] = links_state.cd_vel_bw[i_l, curr_i_j, i_b] + links_state.cd_ang_bw[i_l, next_i_j, i_b] = links_state.cd_ang_bw[i_l, curr_i_j, i_b] + + for i_3 in ti.static(range(3)): + links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + + else: + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + if i_d < joints_info.dof_end[I_j]: + dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( + links_state.cd_ang_bw[i_l, curr_i_j, i_b], + links_state.cd_vel_bw[i_l, curr_i_j, i_b], + dofs_state.cdof_ang[i_d, i_b], + dofs_state.cdof_vel[i_d, i_b], + ) + + links_state.cd_vel_bw[i_l, next_i_j, i_b] = links_state.cd_vel_bw[i_l, curr_i_j, i_b] + links_state.cd_ang_bw[i_l, next_i_j, i_b] = links_state.cd_ang_bw[i_l, curr_i_j, i_b] + + for i_d_ in ( + range(dof_start, joints_info.dof_end[I_j]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + if i_d < joints_info.dof_end[I_j]: + links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + + i_j_ = 0 if ti.static(not is_backward) else n_joints + links_state.cd_vel[i_l, i_b] = links_state.cd_vel_bw[i_l, i_j_, i_b] + links_state.cd_ang[i_l, i_b] = links_state.cd_ang_bw[i_l, i_j_, i_b] @ti.kernel(fastcache=gs.use_fastcache) @@ -4603,6 +5609,7 @@ def kernel_update_geoms( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -4616,6 +5623,7 @@ def kernel_update_geoms( rigid_global_info, static_rigid_sim_config, force_update_fixed_geoms, + is_backward, ) @@ -4629,15 +5637,47 @@ def func_update_geoms( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), + is_backward: ti.template(), ): """ NOTE: this only update geom pose, not its verts and else. """ n_geoms = geoms_info.pos.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_g in range(entities_info.geom_start[i_e], entities_info.geom_end[i_e]): + for i_0 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_geoms) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(geoms_info.pos.shape[0])) + ) + ): + i_e = rigid_global_info.awake_entities[i_0, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 0 + n_geoms = entities_info.geom_end[i_e] - entities_info.geom_start[i_e] + + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(n_geoms) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_geoms_per_entity)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 + if i_1 < (n_geoms if ti.static(static_rigid_sim_config.use_hibernation) else 1): if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: ( geoms_state.pos[i_g, i_b], @@ -4649,19 +5689,6 @@ def func_update_geoms( links_state.quat[geoms_info.link_idx[i_g], i_b], ) geoms_state.verts_updated[i_g, i_b] = False - else: - for i_g in range(n_geoms): - if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: - ( - geoms_state.pos[i_g, i_b], - geoms_state.quat[i_g, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - geoms_info.pos[i_g], - geoms_info.quat[i_g], - links_state.pos[geoms_info.link_idx[i_g], i_b], - links_state.quat[geoms_info.link_idx[i_g], i_b], - ) - geoms_state.verts_updated[i_g, i_b] = False @ti.kernel(fastcache=gs.use_fastcache) @@ -4853,7 +5880,9 @@ def func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_ ) # store entities in the hibernated islands by daisy chaining them - ci.entity_idx_to_next_entity_idx_in_hibernated_island[prev_entity_idx, i_b] = entity_idx + contact_island_state.entity_idx_to_next_entity_idx_in_hibernated_island[ + prev_entity_idx, i_b + ] = entity_idx prev_entity_idx = entity_idx @@ -5013,16 +6042,16 @@ def func_clear_external_force( _B = links_state.pos.shape[1] n_links = links_state.pos.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_l, i_b in ti.ndrange(n_links, _B): + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) @@ -5040,170 +6069,201 @@ def func_torque_and_passive_force( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] - - n_entities = entities_info.n_links.shape[0] - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] - n_links = links_info.root_idx.shape[0] - # compute force based on each dof's ctrl mode ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): + for i_e, i_b in ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]): wakeup = False - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue + EPS = rigid_global_info.EPS[None] - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - force = gs.ti_float(0.0) - if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: - force = dofs_state.ctrl_force[i_d, i_b] - elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY: - force = dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) - elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION and not ( - joint_type == gs.JOINT_TYPE.FREE and i_d >= links_info.dof_start[I_l] + 3 - ): - force = dofs_info.kp[I_d] * ( - dofs_state.ctrl_pos[i_d, i_b] - dofs_state.pos[i_d, i_b] - ) + dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) - - dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( - force, - dofs_info.force_range[I_d][0], - dofs_info.force_range[I_d][1], - ) + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] > 0: + i_j = links_info.joint_start[I_l] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + + if i_d < links_info.dof_end[I_l]: + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + force = gs.ti_float(0.0) + if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: + force = dofs_state.ctrl_force[i_d, i_b] + elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY: + force = dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) + elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION and not ( + joint_type == gs.JOINT_TYPE.FREE and i_d >= links_info.dof_start[I_l] + 3 + ): + force = dofs_info.kp[I_d] * ( + dofs_state.ctrl_pos[i_d, i_b] - dofs_state.pos[i_d, i_b] + ) + dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) + + dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( + force, + dofs_info.force_range[I_d][0], + dofs_info.force_range[I_d][1], + ) - if ti.abs(force) > EPS: - wakeup = True + if ti.abs(force) > EPS: + wakeup = True - dof_start = links_info.dof_start[I_l] - if joint_type == gs.JOINT_TYPE.FREE and ( - dofs_state.ctrl_mode[dof_start + 3, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[dof_start + 4, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[dof_start + 5, i_b] == gs.CTRL_MODE.POSITION - ): - xyz = ti.Vector( - [ - dofs_state.pos[0 + 3 + dof_start, i_b], - dofs_state.pos[1 + 3 + dof_start, i_b], - dofs_state.pos[2 + 3 + dof_start, i_b], - ], - dt=gs.ti_float, - ) + dof_start = links_info.dof_start[I_l] + if joint_type == gs.JOINT_TYPE.FREE and ( + dofs_state.ctrl_mode[dof_start + 3, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[dof_start + 4, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[dof_start + 5, i_b] == gs.CTRL_MODE.POSITION + ): + xyz = ti.Vector( + [ + dofs_state.pos[0 + 3 + dof_start, i_b], + dofs_state.pos[1 + 3 + dof_start, i_b], + dofs_state.pos[2 + 3 + dof_start, i_b], + ], + dt=gs.ti_float, + ) - ctrl_xyz = ti.Vector( - [ - dofs_state.ctrl_pos[0 + 3 + dof_start, i_b], - dofs_state.ctrl_pos[1 + 3 + dof_start, i_b], - dofs_state.ctrl_pos[2 + 3 + dof_start, i_b], - ], - dt=gs.ti_float, - ) + ctrl_xyz = ti.Vector( + [ + dofs_state.ctrl_pos[0 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[1 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[2 + 3 + dof_start, i_b], + ], + dt=gs.ti_float, + ) - quat = gu.ti_xyz_to_quat(xyz) - ctrl_quat = gu.ti_xyz_to_quat(ctrl_xyz) + quat = gu.ti_xyz_to_quat(xyz) + ctrl_quat = gu.ti_xyz_to_quat(ctrl_xyz) - q_diff = gu.ti_transform_quat_by_quat(ctrl_quat, gu.ti_inv_quat(quat)) - rotvec = gu.ti_quat_to_rotvec(q_diff, EPS) + q_diff = gu.ti_transform_quat_by_quat(ctrl_quat, gu.ti_inv_quat(quat)) + rotvec = gu.ti_quat_to_rotvec(q_diff, EPS) - for j in ti.static(range(3)): - i_d = dof_start + 3 + j - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - force = dofs_info.kp[I_d] * rotvec[j] - dofs_info.kv[I_d] * dofs_state.vel[i_d, i_b] + for j in ti.static(range(3)): + i_d = dof_start + 3 + j + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + force = dofs_info.kp[I_d] * rotvec[j] - dofs_info.kv[I_d] * dofs_state.vel[i_d, i_b] - dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( - force, dofs_info.force_range[I_d][0], dofs_info.force_range[I_d][1] - ) + dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( + force, dofs_info.force_range[I_d][0], dofs_info.force_range[I_d][1] + ) - if ti.abs(force) > EPS: - wakeup = True + if ti.abs(force) > EPS: + wakeup = True + + if ti.static(static_rigid_sim_config.use_hibernation): + if entities_state.hibernated[i_e, i_b] and wakeup: + # TODO: migrate this function + func_wakeup_entity_and_its_temp_island( + i_e, + i_b, + entities_state, + entities_info, + dofs_state, + links_state, + geoms_state, + rigid_global_info, + contact_island_state, + ) - if ti.static(static_rigid_sim_config.use_hibernation) and entities_state.hibernated[i_e, i_b] and wakeup: - func_wakeup_entity_and_its_temp_island( - i_e, - i_b, - entities_state, - entities_info, - dofs_state, - links_state, - geoms_state, - rigid_global_info, - contact_island_state, + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(dofs_state.ctrl_mode.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) ) + if ti.static(not is_backward) + else ( + # Static inner for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_info.root_idx.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: - q_start = links_info.q_start[I_l] - dof_start = links_info.dof_start[I_l] - dof_end = links_info.dof_end[I_l] - - for j_d in range(dof_end - dof_start): - I_d = ( - [dof_start + j_d, i_b] - if ti.static(static_rigid_sim_config.batch_dofs_info) - else dof_start + j_d - ) - dofs_state.qf_passive[dof_start + j_d, i_b] += ( - -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + if links_info.n_dofs[I_l] > 0: + i_j = links_info.joint_start[I_l] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] - if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: - q_start = links_info.q_start[I_l] - dof_start = links_info.dof_start[I_l] - dof_end = links_info.dof_end[I_l] + if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: + q_start = links_info.q_start[I_l] + dof_start = links_info.dof_start[I_l] + dof_end = links_info.dof_end[I_l] - for j_d in range(dof_end - dof_start): - I_d = ( - [dof_start + j_d, i_b] - if ti.static(static_rigid_sim_config.batch_dofs_info) - else dof_start + j_d - ) - dofs_state.qf_passive[dof_start + j_d, i_b] += ( - -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] - ) + for j_d in ( + range(dof_end - dof_start) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + if j_d < dof_end: + I_d = ( + [dof_start + j_d, i_b] + if ti.static(static_rigid_sim_config.batch_dofs_info) + else dof_start + j_d + ) + dofs_state.qf_passive[dof_start + j_d, i_b] += ( + -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] + ) @ti.func @@ -5215,90 +6275,85 @@ def func_update_acc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_links = links_info.root_idx.shape[0] - n_entities = entities_info.n_links.shape[0] + # Assume this is the outermost loop + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p == -1: - links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( - 1 - entities_info.gravity_compensation[i_e] - ) - links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] - links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] - - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ( - links_state.cacc_lin[i_l, i_b] - + local_cdd_vel - + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - links_state.cacc_ang[i_l, i_b] = ( - links_state.cacc_ang[i_l, i_b] - + local_cdd_ang - + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] + for i_l_ in ( + range(entities_info.link_start[i_e], entities_info.link_end[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) - if i_p == -1: - links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( - 1 - entities_info.gravity_compensation[i_e] - ) - links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] - links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] - - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - # cacc = cacc_parent + cdofdot * qvel + cdof * qacc - local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ( - links_state.cacc_lin[i_l, i_b] - + local_cdd_vel - + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - links_state.cacc_ang[i_l, i_b] = ( - links_state.cacc_ang[i_l, i_b] - + local_cdd_ang - + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) + if i_l < entities_info.link_end[i_e]: + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_p = links_info.parent_idx[I_l] + + if i_p == -1: + links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( + 1 - entities_info.gravity_compensation[i_e] + ) + links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + else: + links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] + links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] + links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] + + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + + if i_d < links_info.dof_end[I_l]: + # cacc = cacc_parent + cdofdot * qvel + cdof * qacc + local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + links_state.cdd_vel[i_l, i_b] += local_cdd_vel + links_state.cdd_ang[i_l, i_b] += local_cdd_ang + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] += ( + local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) + links_state.cacc_ang[i_l, i_b] += ( + local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) @ti.func @@ -5308,16 +6363,37 @@ def func_update_force( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = links_state.pos.shape[1] - n_links = links_info.root_idx.shape[0] - n_entities = entities_info.n_links.shape[0] - - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_info.root_idx.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) f1_ang, f1_vel = gu.inertial_mul( links_state.cinr_pos[i_l, i_b], @@ -5333,70 +6409,65 @@ def func_update_force( links_state.cd_vel[i_l, i_b], links_state.cd_ang[i_l, i_b], ) - f2_ang, f2_vel = gu.motion_cross_force( + f3_ang, f3_vel = gu.motion_cross_force( links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel ) links_state.cfrc_vel[i_l, i_b] = ( - f1_vel + f2_vel + links_state.cfrc_applied_vel[i_l, i_b] + links_state.cfrc_coupling_vel[i_l, i_b] + f1_vel + f3_vel + links_state.cfrc_applied_vel[i_l, i_b] + links_state.cfrc_coupling_vel[i_l, i_b] ) links_state.cfrc_ang[i_l, i_b] = ( - f1_ang + f2_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_coupling_ang[i_l, i_b] + f1_ang + f3_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_coupling_ang[i_l, i_b] ) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - if i_p != -1: - links_state.cfrc_vel[i_p, i_b] = links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] = links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - f1_ang, f1_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cdd_vel[i_l, i_b], - links_state.cdd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cd_vel[i_l, i_b], - links_state.cd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.motion_cross_force( - links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel - ) - - links_state.cfrc_vel[i_l, i_b] = ( - f1_vel + f2_vel + links_state.cfrc_applied_vel[i_l, i_b] + links_state.cfrc_coupling_vel[i_l, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, links_state.pos.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) ) - links_state.cfrc_ang[i_l, i_b] = ( - f1_ang + f2_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_coupling_ang[i_l, i_b] + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_l_ in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i_l_ - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - if i_p != -1: - links_state.cfrc_vel[i_p, i_b] = links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] = links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] + for i_l_ in ( + range(entities_info.n_links[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) + ): + if i_l_ < entities_info.n_links[i_e]: + i_l = entities_info.link_end[i_e] - 1 - i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_p = links_info.parent_idx[I_l] + if i_p != -1: + links_state.cfrc_vel[i_p, i_b] += links_state.cfrc_vel[i_l, i_b] + links_state.cfrc_ang[i_p, i_b] += links_state.cfrc_ang[i_l, i_b] # Clear coupling forces after use ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - links_state.cfrc_coupling_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cfrc_coupling_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + for I in ti.grouped(ti.ndrange(*links_state.cfrc_coupling_ang.shape)): + links_state.cfrc_coupling_ang[I] = ti.Vector.zero(gs.ti_float, 3) + links_state.cfrc_coupling_vel[I] = ti.Vector.zero(gs.ti_float, 3) @ti.func @@ -5430,49 +6501,75 @@ def func_bias_force( links_info: array_class.LinksInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_links = links_info.root_idx.shape[0] - - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(links_info.root_idx.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( - links_state.cfrc_ang[i_l, i_b] - ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) - - dofs_state.force[i_d, i_b] = ( - dofs_state.qf_passive[i_d, i_b] - - dofs_state.qf_bias[i_d, i_b] - + dofs_state.qf_applied[i_d, i_b] - # + self.dofs_state.qf_actuator[i_d, i_b] - ) - - dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] - - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + for i_d_ in ( + range(links_info.dof_start[I_l], links_info.dof_end[I_l]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) + ): + i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + if i_d < links_info.dof_end[I_l]: + dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( + links_state.cfrc_ang[i_l, i_b] + ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) + + dofs_state.force[i_d, i_b] = ( + dofs_state.qf_passive[i_d, i_b] + - dofs_state.qf_bias[i_d, i_b] + + dofs_state.qf_applied[i_d, i_b] + # + self.dofs_state.qf_actuator[i_d, i_b] + ) - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( - links_state.cfrc_ang[i_l, i_b] - ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) + dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] - dofs_state.force[i_d, i_b] = ( - dofs_state.qf_passive[i_d, i_b] - - dofs_state.qf_bias[i_d, i_b] - + dofs_state.qf_applied[i_d, i_b] - # + self.dofs_state.qf_actuator[i_d, i_b] - ) - dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] +@ti.kernel +def kernel_compute_qacc( + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), + is_backward: ti.template(), +): + func_compute_qacc( + dofs_state=dofs_state, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) @ti.func @@ -5481,32 +6578,57 @@ def func_compute_qacc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_entities = entities_info.n_links.shape[0] - func_solve_mass( vec=dofs_state.force, out=dofs_state.acc_smooth, + out_bw=dofs_state.acc_smooth_bw, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_d1_ in range(entities_info.n_dofs[i_e]): + # Assume this is the outermost loop + ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) + for i_0, i_b in ( + ti.ndrange(1, dofs_state.ctrl_mode.shape[1]) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_entities)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + + for i_d1_ in ( + range(entities_info.n_dofs[i_e]) + if ti.static(not is_backward) + else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) + ): i_d1 = entities_info.dof_start[i_e] + i_d1_ - dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] - else: - ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) - for i_e, i_b in ti.ndrange(n_entities, _B): - for i_d1_ in range(entities_info.n_dofs[i_e]): - i_d1 = entities_info.dof_start[i_e] + i_d1_ - dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] + if i_d1 < entities_info.dof_end[i_e]: + dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] @ti.func @@ -5516,56 +6638,81 @@ def func_integrate( joints_info: array_class.JointsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), + is_backward: ti.template(), ): - EPS = rigid_global_info.EPS[None] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + (ti.ndrange(1, dofs_state.ctrl_mode.shape[1])) + if ti.static(static_rigid_sim_config.use_hibernation) + else (ti.ndrange(dofs_state.ctrl_mode.shape[0], dofs_state.ctrl_mode.shape[1])) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_d = ( + rigid_global_info.awake_dofs[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] - n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] - dofs_state.vel[i_d, i_b] = ( + dofs_state.vel_next[i_d, i_b] = ( dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + (ti.ndrange(1, dofs_state.ctrl_mode.shape[1])) + if ti.static(static_rigid_sim_config.use_hibernation) + else (ti.ndrange(links_info.root_idx.shape[0], dofs_state.ctrl_mode.shape[1])) + ): + for i_1 in ( + ( + # Dynamic inner loop for forward pass + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ) + if ti.static(not is_backward) + else ( + # Static inner loop for backward pass + ti.static(range(static_rigid_sim_config.max_n_awake_links)) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.static(range(1)) + ) + ): + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): + i_l = ( + rigid_global_info.awake_links[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] > 0: + EPS = rigid_global_info.EPS[None] + dof_start = links_info.dof_start[I_l] + q_start = links_info.q_start[I_l] + q_end = links_info.q_end[I_l] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + i_j = links_info.joint_start[I_l] I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - dof_start = joints_info.dof_start[I_j] - q_start = joints_info.q_start[I_j] - q_end = joints_info.q_end[I_j] - joint_type = joints_info.type[I_j] - if joint_type == gs.JOINT_TYPE.FREE: - rot = ti.Vector( - [ - rigid_global_info.qpos[q_start + 3, i_b], - rigid_global_info.qpos[q_start + 4, i_b], - rigid_global_info.qpos[q_start + 5, i_b], - rigid_global_info.qpos[q_start + 6, i_b], - ] - ) - ang = ( - ti.Vector( - [ - dofs_state.vel[dof_start + 3, i_b], - dofs_state.vel[dof_start + 4, i_b], - dofs_state.vel[dof_start + 5, i_b], - ] - ) - * rigid_global_info.substep_dt[None] - ) - qrot = gu.ti_rotvec_to_quat(ang, EPS) - rot = gu.ti_transform_quat_by_quat(qrot, rot) + if joint_type == gs.JOINT_TYPE.FREE: pos = ti.Vector( [ rigid_global_info.qpos[q_start, i_b], @@ -5575,118 +6722,346 @@ def func_integrate( ) vel = ti.Vector( [ - dofs_state.vel[dof_start, i_b], - dofs_state.vel[dof_start + 1, i_b], - dofs_state.vel[dof_start + 2, i_b], + dofs_state.vel_next[dof_start, i_b], + dofs_state.vel_next[dof_start + 1, i_b], + dofs_state.vel_next[dof_start + 2, i_b], ] ) - pos = pos + vel * rigid_global_info.substep_dt[None] + pos += vel * rigid_global_info.substep_dt[None] for j in ti.static(range(3)): - rigid_global_info.qpos[q_start + j, i_b] = pos[j] - for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j + 3, i_b] = rot[j] - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - rot = ti.Vector( + rigid_global_info.qpos_next[q_start + j, i_b] = pos[j] + if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: + rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0 + rot0 = ti.Vector( [ - rigid_global_info.qpos[q_start + 0, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[q_start + rot_offset + 0, i_b], + rigid_global_info.qpos[q_start + rot_offset + 1, i_b], + rigid_global_info.qpos[q_start + rot_offset + 2, i_b], + rigid_global_info.qpos[q_start + rot_offset + 3, i_b], ] ) ang = ( ti.Vector( [ - dofs_state.vel[dof_start + 3, i_b], - dofs_state.vel[dof_start + 4, i_b], - dofs_state.vel[dof_start + 5, i_b], + dofs_state.vel_next[dof_start + rot_offset + 0, i_b], + dofs_state.vel_next[dof_start + rot_offset + 1, i_b], + dofs_state.vel_next[dof_start + rot_offset + 2, i_b], ] ) * rigid_global_info.substep_dt[None] ) qrot = gu.ti_rotvec_to_quat(ang, EPS) - rot = gu.ti_transform_quat_by_quat(qrot, rot) + rot = gu.ti_transform_quat_by_quat(qrot, rot0) for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j, i_b] = rot[j] - + rigid_global_info.qpos_next[q_start + j + rot_offset, i_b] = rot[j] else: - for j in range(q_end - q_start): - rigid_global_info.qpos[q_start + j, i_b] = ( - rigid_global_info.qpos[q_start + j, i_b] - + dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None] - ) + for j_ in ( + (range(q_end - q_start)) + if ti.static(not is_backward) + else (ti.static(range(static_rigid_sim_config.max_n_qs_per_link))) + ): + j = q_start + j_ + if j < q_end: + rigid_global_info.qpos_next[j, i_b] = ( + rigid_global_info.qpos[j, i_b] + + dofs_state.vel_next[dof_start + j_, i_b] * rigid_global_info.substep_dt[None] + ) - else: + +@ti.func +def func_copy_next_to_curr( + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*dofs_state.vel.shape)): + dofs_state.vel[I] = dofs_state.vel_next[I] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*rigid_global_info.qpos.shape)): + rigid_global_info.qpos[I] = rigid_global_info.qpos_next[I] + + +@ti.func +def func_copy_next_to_curr_grad( + f: ti.int32, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + n_dofs = dofs_state.vel.shape[0] + n_qs = rigid_global_info.qpos.shape[0] + _B = dofs_state.vel.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + dofs_state.vel_next.grad[i_d, i_b] = dofs_state.vel.grad[i_d, i_b] + dofs_state.vel.grad[i_d, i_b] = 0.0 + dofs_state.vel[i_d, i_b] = rigid_adjoint_cache.dofs_vel[f, i_d, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + rigid_global_info.qpos_next.grad[i_q, i_b] = rigid_global_info.qpos.grad[i_q, i_b] + rigid_global_info.qpos.grad[i_q, i_b] = 0.0 + rigid_global_info.qpos[i_q, i_b] = rigid_adjoint_cache.qpos[f, i_q, i_b] + + +@ti.func +def func_save_adjoint_cache( + f: ti.int32, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + n_dofs = dofs_state.vel.shape[0] + n_qs = rigid_global_info.qpos.shape[0] + _B = dofs_state.vel.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b] + rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + rigid_adjoint_cache.qpos[f, i_q, i_b] = rigid_global_info.qpos[i_q, i_b] + + +@ti.func +def func_load_adjoint_cache( + f: ti.int32, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + n_dofs = dofs_state.vel.shape[0] + n_qs = rigid_global_info.qpos.shape[0] + _B = dofs_state.vel.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + dofs_state.vel[i_d, i_b] = rigid_adjoint_cache.dofs_vel[f, i_d, i_b] + dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + rigid_global_info.qpos[i_q, i_b] = rigid_adjoint_cache.qpos[f, i_q, i_b] + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_prepare_backward_substep( + f: ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_state: array_class.GeomsState, + geoms_info: array_class.GeomsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + dofs_state_adjoint_cache: array_class.DofsState, + links_state_adjoint_cache: array_class.LinksState, + joints_state_adjoint_cache: array_class.JointsState, + geoms_state_adjoint_cache: array_class.GeomsState, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + # Load the current state from adjoint cache + func_load_adjoint_cache( + f=f, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + rigid_adjoint_cache=rigid_adjoint_cache, + static_rigid_sim_config=static_rigid_sim_config, + ) + + # If mujoco compatibility is disabled, update the cartesian space and save the results to adjoint cache. This is + # because the cartesian space is overwritten later by other kernels if mujoco compatibility was disabled. + if not static_rigid_sim_config.enable_mujoco_compatibility: ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.vel[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * rigid_global_info.substep_dt[None] + for i_b in range(links_state.pos.shape[1]): + func_update_cartesian_space( + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + geoms_state=geoms_state, + geoms_info=geoms_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=False, ) + # Save results of [update_cartesian_space] to adjoint cache + func_copy_cartesian_space( + src_dofs_state=dofs_state, + src_links_state=links_state, + src_joints_state=joints_state, + src_geoms_state=geoms_state, + dst_dofs_state=dofs_state_adjoint_cache, + dst_links_state=links_state_adjoint_cache, + dst_joints_state=joints_state_adjoint_cache, + dst_geoms_state=geoms_state_adjoint_cache, + static_rigid_sim_config=static_rigid_sim_config, + ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - dof_start = links_info.dof_start[I_l] - q_start = links_info.q_start[I_l] - q_end = links_info.q_end[I_l] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_begin_backward_substep( + f: ti.int32, + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + geoms_state: array_class.GeomsState, + geoms_info: array_class.GeomsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + dofs_state_adjoint_cache: array_class.DofsState, + links_state_adjoint_cache: array_class.LinksState, + joints_state_adjoint_cache: array_class.JointsState, + geoms_state_adjoint_cache: array_class.GeomsState, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +) -> ti.i32: + is_grad_valid = func_is_grad_valid( + rigid_global_info=rigid_global_info, + dofs_state=dofs_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + if is_grad_valid: + func_copy_next_to_curr_grad( + f=f, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + rigid_adjoint_cache=rigid_adjoint_cache, + static_rigid_sim_config=static_rigid_sim_config, + ) - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + if not static_rigid_sim_config.enable_mujoco_compatibility: + func_copy_cartesian_space( + src_dofs_state=dofs_state_adjoint_cache, + src_links_state=links_state_adjoint_cache, + src_joints_state=joints_state_adjoint_cache, + src_geoms_state=geoms_state_adjoint_cache, + dst_dofs_state=dofs_state, + dst_links_state=links_state, + dst_joints_state=joints_state, + dst_geoms_state=geoms_state, + static_rigid_sim_config=static_rigid_sim_config, + ) - if joint_type == gs.JOINT_TYPE.FREE: - pos = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ] - ) - vel = ti.Vector( - [ - dofs_state.vel[dof_start, i_b], - dofs_state.vel[dof_start + 1, i_b], - dofs_state.vel[dof_start + 2, i_b], - ] - ) - pos = pos + vel * rigid_global_info.substep_dt[None] - for j in ti.static(range(3)): - rigid_global_info.qpos[q_start + j, i_b] = pos[j] - if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: - rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0 - rot = ti.Vector( - [ - rigid_global_info.qpos[q_start + rot_offset + 0, i_b], - rigid_global_info.qpos[q_start + rot_offset + 1, i_b], - rigid_global_info.qpos[q_start + rot_offset + 2, i_b], - rigid_global_info.qpos[q_start + rot_offset + 3, i_b], - ] - ) - ang = ( - ti.Vector( - [ - dofs_state.vel[dof_start + rot_offset + 0, i_b], - dofs_state.vel[dof_start + rot_offset + 1, i_b], - dofs_state.vel[dof_start + rot_offset + 2, i_b], - ] - ) - * rigid_global_info.substep_dt[None] - ) - qrot = gu.ti_rotvec_to_quat(ang, EPS) - rot = gu.ti_transform_quat_by_quat(qrot, rot) - for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j + rot_offset, i_b] = rot[j] - else: - for j in range(q_end - q_start): - rigid_global_info.qpos[q_start + j, i_b] = ( - rigid_global_info.qpos[q_start + j, i_b] - + dofs_state.vel[dof_start + j, i_b] * rigid_global_info.substep_dt[None] - ) + return is_grad_valid + + +@ti.func +def func_is_grad_valid( + rigid_global_info: array_class.RigidGlobalInfo, + dofs_state: array_class.DofsState, + static_rigid_sim_config: ti.template(), +): + is_valid = True + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*rigid_global_info.qpos.shape)): + if ti.math.isnan(rigid_global_info.qpos.grad[I]): + is_valid = False + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*dofs_state.vel.shape)): + if ti.math.isnan(dofs_state.vel.grad[I]): + is_valid = False + + return is_valid + + +@ti.func +def func_copy_cartesian_space( + src_dofs_state: array_class.DofsState, + src_links_state: array_class.LinksState, + src_joints_state: array_class.JointsState, + src_geoms_state: array_class.GeomsState, + dst_dofs_state: array_class.DofsState, + dst_links_state: array_class.LinksState, + dst_joints_state: array_class.JointsState, + dst_geoms_state: array_class.GeomsState, + static_rigid_sim_config: ti.template(), +): + # Copy outputs of [kernel_update_cartesian_space] among [dofs, links, joints, geoms] states. This is used to restore + # the outputs that were overwritten if we disabled mujoco compatibility for backward pass. + + # dofs state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*src_dofs_state.pos.shape)): + # pos, cdof_ang, cdof_vel, cdofvel_ang, cdofvel_vel, cdofd_ang, cdofd_vel + dst_dofs_state.pos[I] = src_dofs_state.pos[I] + dst_dofs_state.cdof_ang[I] = src_dofs_state.cdof_ang[I] + dst_dofs_state.cdof_vel[I] = src_dofs_state.cdof_vel[I] + dst_dofs_state.cdofvel_ang[I] = src_dofs_state.cdofvel_ang[I] + dst_dofs_state.cdofvel_vel[I] = src_dofs_state.cdofvel_vel[I] + dst_dofs_state.cdofd_ang[I] = src_dofs_state.cdofd_ang[I] + dst_dofs_state.cdofd_vel[I] = src_dofs_state.cdofd_vel[I] + + # links state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*src_links_state.pos.shape)): + # pos, quat, root_COM, mass_sum, i_pos, i_quat, cinr_inertial, cinr_pos, cinr_quat, cinr_mass, j_pos, j_quat, + # cd_vel, cd_ang + dst_links_state.pos[I] = src_links_state.pos[I] + dst_links_state.quat[I] = src_links_state.quat[I] + dst_links_state.root_COM[I] = src_links_state.root_COM[I] + dst_links_state.mass_sum[I] = src_links_state.mass_sum[I] + dst_links_state.i_pos[I] = src_links_state.i_pos[I] + dst_links_state.i_quat[I] = src_links_state.i_quat[I] + dst_links_state.cinr_inertial[I] = src_links_state.cinr_inertial[I] + dst_links_state.cinr_pos[I] = src_links_state.cinr_pos[I] + dst_links_state.cinr_quat[I] = src_links_state.cinr_quat[I] + dst_links_state.cinr_mass[I] = src_links_state.cinr_mass[I] + dst_links_state.j_pos[I] = src_links_state.j_pos[I] + dst_links_state.j_quat[I] = src_links_state.j_quat[I] + dst_links_state.cd_vel[I] = src_links_state.cd_vel[I] + dst_links_state.cd_ang[I] = src_links_state.cd_ang[I] + + # joints state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*src_joints_state.xanchor.shape)): + # xanchor, xaxis + dst_joints_state.xanchor[I] = src_joints_state.xanchor[I] + dst_joints_state.xaxis[I] = src_joints_state.xaxis[I] + + # geoms state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for I in ti.grouped(ti.ndrange(*src_geoms_state.pos.shape)): + # pos, quat, verts_updated + dst_geoms_state.pos[I] = src_geoms_state.pos[I] + dst_geoms_state.quat[I] = src_geoms_state.quat[I] + dst_geoms_state.verts_updated[I] = src_geoms_state.verts_updated[I] + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_copy_acc( + f: ti.int32, + dofs_state: array_class.DofsState, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + n_dofs = dofs_state.vel.shape[0] + _B = dofs_state.vel.shape[1] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b] @ti.func @@ -5822,6 +7197,7 @@ def kernel_update_vgeoms_render_T( def kernel_get_state( qpos: ti.types.ndarray(), vel: ti.types.ndarray(), + acc: ti.types.ndarray(), links_pos: ti.types.ndarray(), links_quat: ti.types.ndarray(), i_pos_shift: ti.types.ndarray(), @@ -5847,6 +7223,7 @@ def kernel_get_state( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(n_dofs, _B): vel[i_b, i_d] = dofs_state.vel[i_d, i_b] + acc[i_b, i_d] = dofs_state.acc[i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(n_links, _B): @@ -5867,6 +7244,7 @@ def kernel_get_state( def kernel_set_state( qpos: ti.types.ndarray(), dofs_vel: ti.types.ndarray(), + dofs_acc: ti.types.ndarray(), links_pos: ti.types.ndarray(), links_quat: ti.types.ndarray(), i_pos_shift: ti.types.ndarray(), @@ -5892,6 +7270,7 @@ def kernel_set_state( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b_ in ti.ndrange(n_dofs, envs_idx.shape[0]): dofs_state.vel[i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d] + dofs_state.acc[i_d, envs_idx[i_b_]] = dofs_acc[envs_idx[i_b_], i_d] dofs_state.ctrl_force[i_d, envs_idx[i_b_]] = gs.ti_float(0.0) dofs_state.ctrl_mode[i_d, envs_idx[i_b_]] = gs.CTRL_MODE.FORCE @@ -5909,6 +7288,39 @@ def kernel_set_state( geoms_state.friction_ratio[i_l, envs_idx[i_b_]] = friction_ratio[envs_idx[i_b_], i_l] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_get_state_grad( + qpos_grad: ti.types.ndarray(), + vel_grad: ti.types.ndarray(), + links_pos_grad: ti.types.ndarray(), + links_quat_grad: ti.types.ndarray(), + links_state: array_class.LinksState, + dofs_state: array_class.DofsState, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + n_qs = qpos_grad.shape[1] + n_dofs = vel_grad.shape[1] + n_links = links_pos_grad.shape[1] + _B = qpos_grad.shape[0] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(n_qs, _B): + rigid_global_info.qpos.grad[i_q, i_b] += qpos_grad[i_b, i_q] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + dofs_state.vel.grad[i_d, i_b] += vel_grad[i_b, i_d] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(n_links, _B): + for j in ti.static(range(3)): + links_state.pos.grad[i_l, i_b][j] += links_pos_grad[i_b, i_l, j] + for j in ti.static(range(4)): + links_state.quat.grad[i_l, i_b][j] += links_quat_grad[i_b, i_l, j] + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_links_pos( relative: ti.i32, @@ -5944,6 +7356,35 @@ def kernel_set_links_pos( ) +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_set_links_pos_grad( + relative: ti.i32, + pos_grad: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + i_b = envs_idx[i_b_] + i_l = links_idx[i_l_] + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: + for j in ti.static(range(3)): + pos_grad[i_b_, i_l_, j] = links_state.pos.grad[i_l, i_b][j] + links_state.pos.grad[i_l, i_b][j] = 0.0 + else: + q_start = links_info.q_start[I_l] + for j in ti.static(range(3)): + pos_grad[i_b_, i_l_, j] = rigid_global_info.qpos.grad[q_start + j, i_b] + rigid_global_info.qpos.grad[q_start + j, i_b] = 0.0 + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_links_quat( relative: ti.i32, @@ -5998,6 +7439,34 @@ def kernel_set_links_quat( rigid_global_info.qpos[q_start + j + 3, i_b] = quat[i_b_, i_l_, j] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_set_links_quat_grad( + relative: ti.i32, + quat_grad: ti.types.ndarray(), + links_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): + i_b = envs_idx[i_b_] + i_l = links_idx[i_l_] + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: + for j in ti.static(range(4)): + quat_grad[i_b_, i_l_, j] = links_state.quat.grad[i_l, i_b][j] + links_state.quat.grad[i_l, i_b][j] = 0.0 + else: + q_start = links_info.q_start[I_l] + for j in ti.static(range(4)): + quat_grad[i_b_, i_l_, j] = rigid_global_info.qpos.grad[q_start + j + 3, i_b] + rigid_global_info.qpos.grad[q_start + j + 3, i_b] = 0.0 + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_links_mass_shift( mass: ti.types.ndarray(), @@ -6286,6 +7755,20 @@ def kernel_set_dofs_velocity( dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = velocity[i_b_, i_d_] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_set_dofs_velocity_grad( + velocity_grad: ti.types.ndarray(), + dofs_idx: ti.types.ndarray(), + envs_idx: ti.types.ndarray(), + dofs_state: array_class.DofsState, + static_rigid_sim_config: ti.template(), +): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + velocity_grad[i_b_, i_d_] = dofs_state.vel.grad[dofs_idx[i_d_], envs_idx[i_b_]] + dofs_state.vel.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 + + @ti.kernel(fastcache=gs.use_fastcache) def kernel_set_dofs_zero_velocity( dofs_idx: ti.types.ndarray(), diff --git a/genesis/engine/states/__init__.py b/genesis/engine/states/__init__.py index e69de29bb..2e316a1de 100644 --- a/genesis/engine/states/__init__.py +++ b/genesis/engine/states/__init__.py @@ -0,0 +1,2 @@ +from .solvers import * +from .cache import * diff --git a/genesis/engine/states/entities.py b/genesis/engine/states/entities.py index b0135547a..b6ee1f6dd 100644 --- a/genesis/engine/states/entities.py +++ b/genesis/engine/states/entities.py @@ -188,3 +188,41 @@ def vel(self): @property def active(self): return self._active + + +class RigidEntityState(RBC): + """ + Dynamic state queried from a genesis RigidEntity. + """ + + def __init__(self, entity, s_global): + self._entity = entity + self._s_global = s_global + + num_batch = self._entity._solver._B + requires_grad = self._entity.scene.requires_grad + scene = self._entity.scene + self._pos = gs.zeros((num_batch, 3), dtype=float, requires_grad=requires_grad, scene=scene) + self._quat = gs.zeros((num_batch, 4), dtype=float, requires_grad=requires_grad, scene=scene) + + def serializable(self): + self._entity = None + + self._pos = self._pos.detach() + self._quat = self._quat.detach() + + @property + def entity(self): + return self._entity + + @property + def s_global(self): + return self._s_global + + @property + def pos(self): + return self._pos + + @property + def quat(self): + return self._quat diff --git a/genesis/engine/states/solvers.py b/genesis/engine/states/solvers.py index c1182c501..5135346e9 100644 --- a/genesis/engine/states/solvers.py +++ b/genesis/engine/states/solvers.py @@ -48,9 +48,11 @@ class RigidSolverState: Dynamic state queried from a RigidSolver. """ - def __init__(self, scene): + def __init__(self, scene, s_global): self.scene = scene + self._s_global = s_global + _B = scene.sim.rigid_solver._B args = { "dtype": gs.tc_float, @@ -59,6 +61,7 @@ def __init__(self, scene): } self.qpos = gs.zeros((_B, scene.sim.rigid_solver.n_qs), **args) self.dofs_vel = gs.zeros((_B, scene.sim.rigid_solver.n_dofs), **args) + self.dofs_acc = gs.zeros((_B, scene.sim.rigid_solver.n_dofs), **args) self.links_pos = gs.zeros((_B, scene.sim.rigid_solver.n_links, 3), **args) self.links_quat = gs.zeros((_B, scene.sim.rigid_solver.n_links, 4), **args) self.i_pos_shift = gs.zeros((_B, scene.sim.rigid_solver.n_links, 3), **args) @@ -75,6 +78,10 @@ def serializable(self): self.mass_shift = self.mass_shift.detach() self.friction_ratio = self.friction_ratio.detach() + @property + def s_global(self): + return self._s_global + class AvatarSolverState: """ diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index d8dcb07f7..76922313f 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -18,6 +18,7 @@ V_MAT = ti.Matrix.ndarray if gs.use_ndarray else ti.Matrix.field DATA_ORIENTED = partial(dataclasses.dataclass, frozen=True) if gs.use_ndarray else ti.data_oriented +PLACEHOLDER = V(dtype=gs.ti_float, shape=()) def maybe_shape(shape, is_on): @@ -73,6 +74,7 @@ def V_SCALAR_FROM(dtype, value): @DATA_ORIENTED class StructRigidGlobalInfo(metaclass=BASE_METACLASS): + # *_bw: Cache for backward pass n_awake_dofs: V_ANNOTATION awake_dofs: V_ANNOTATION n_awake_entities: V_ANNOTATION @@ -81,11 +83,13 @@ class StructRigidGlobalInfo(metaclass=BASE_METACLASS): awake_links: V_ANNOTATION qpos0: V_ANNOTATION qpos: V_ANNOTATION + qpos_next: V_ANNOTATION links_T: V_ANNOTATION envs_offset: V_ANNOTATION geoms_init_AABB: V_ANNOTATION mass_mat: V_ANNOTATION mass_mat_L: V_ANNOTATION + mass_mat_L_bw: V_ANNOTATION mass_mat_D_inv: V_ANNOTATION mass_mat_mask: V_ANNOTATION meaninertia: V_ANNOTATION @@ -108,6 +112,7 @@ class StructRigidGlobalInfo(metaclass=BASE_METACLASS): def get_rigid_global_info(solver): _B = solver._B + requires_grad = solver._requires_grad mass_mat_shape = (solver.n_dofs_, solver.n_dofs_, _B) if math.prod(mass_mat_shape) > np.iinfo(np.int32).max: @@ -126,14 +131,16 @@ def get_rigid_global_info(solver): awake_entities=V(dtype=gs.ti_int, shape=(solver.n_entities_, _B)), awake_links=V(dtype=gs.ti_int, shape=(solver.n_links_, _B)), qpos0=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B)), - qpos=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B)), + qpos=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B), needs_grad=requires_grad), + qpos_next=V(dtype=gs.ti_float, shape=(solver.n_qs_, _B), needs_grad=requires_grad), links_T=V_MAT(n=4, m=4, dtype=gs.ti_float, shape=(solver.n_links_,)), geoms_init_AABB=V_VEC(3, dtype=gs.ti_float, shape=(solver.n_geoms_, 8)), - mass_mat_D_inv=V(dtype=gs.ti_float, shape=(solver.n_dofs_, _B)), + mass_mat=V(dtype=gs.ti_float, shape=mass_mat_shape, needs_grad=requires_grad), + mass_mat_L=V(dtype=gs.ti_float, shape=mass_mat_shape, needs_grad=requires_grad), + mass_mat_L_bw=V(dtype=gs.ti_float, shape=(2, solver.n_dofs_, solver.n_dofs_, _B), needs_grad=requires_grad), + mass_mat_D_inv=V(dtype=gs.ti_float, shape=(solver.n_dofs_, _B), needs_grad=requires_grad), mass_mat_mask=V(dtype=gs.ti_bool, shape=(solver.n_entities_, _B)), mass_parent_mask=V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_)), - mass_mat=V(dtype=gs.ti_float, shape=mass_mat_shape), - mass_mat_L=V(dtype=gs.ti_float, shape=mass_mat_shape), substep_dt=V_SCALAR_FROM(dtype=gs.ti_float, value=solver._substep_dt), iterations=V_SCALAR_FROM(dtype=gs.ti_int, value=solver._options.iterations), tolerance=V_SCALAR_FROM(dtype=gs.ti_float, value=solver._options.tolerance), @@ -1136,6 +1143,7 @@ def get_dofs_info(solver): @DATA_ORIENTED class StructDofsState(metaclass=BASE_METACLASS): + # *_bw: Cache to avoid overwriting for backward pass force: V_ANNOTATION qf_bias: V_ANNOTATION qf_passive: V_ANNOTATION @@ -1145,8 +1153,11 @@ class StructDofsState(metaclass=BASE_METACLASS): pos: V_ANNOTATION vel: V_ANNOTATION vel_prev: V_ANNOTATION + vel_next: V_ANNOTATION acc: V_ANNOTATION + acc_bw: V_ANNOTATION acc_smooth: V_ANNOTATION + acc_smooth_bw: V_ANNOTATION qf_smooth: V_ANNOTATION qf_constraint: V_ANNOTATION cdof_ang: V_ANNOTATION @@ -1166,32 +1177,36 @@ class StructDofsState(metaclass=BASE_METACLASS): def get_dofs_state(solver): shape = (solver.n_dofs_, solver._B) + requires_grad = solver._requires_grad return StructDofsState( - force=V(dtype=gs.ti_float, shape=shape), - qf_bias=V(dtype=gs.ti_float, shape=shape), - qf_passive=V(dtype=gs.ti_float, shape=shape), - qf_actuator=V(dtype=gs.ti_float, shape=shape), - qf_applied=V(dtype=gs.ti_float, shape=shape), - act_length=V(dtype=gs.ti_float, shape=shape), - pos=V(dtype=gs.ti_float, shape=shape), - vel=V(dtype=gs.ti_float, shape=shape), - vel_prev=V(dtype=gs.ti_float, shape=shape), - acc=V(dtype=gs.ti_float, shape=shape), - acc_smooth=V(dtype=gs.ti_float, shape=shape), - qf_smooth=V(dtype=gs.ti_float, shape=shape), - qf_constraint=V(dtype=gs.ti_float, shape=shape), - cdof_ang=V(dtype=gs.ti_vec3, shape=shape), - cdof_vel=V(dtype=gs.ti_vec3, shape=shape), - cdofvel_ang=V(dtype=gs.ti_vec3, shape=shape), - cdofvel_vel=V(dtype=gs.ti_vec3, shape=shape), - cdofd_ang=V(dtype=gs.ti_vec3, shape=shape), - cdofd_vel=V(dtype=gs.ti_vec3, shape=shape), - f_vel=V(dtype=gs.ti_vec3, shape=shape), - f_ang=V(dtype=gs.ti_vec3, shape=shape), - ctrl_force=V(dtype=gs.ti_float, shape=shape), - ctrl_pos=V(dtype=gs.ti_float, shape=shape), - ctrl_vel=V(dtype=gs.ti_float, shape=shape), + force=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_bias=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_passive=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_actuator=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_applied=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + act_length=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + pos=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + vel=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + vel_prev=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + vel_next=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + acc=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + acc_bw=V(dtype=gs.ti_float, shape=(2, solver.n_dofs_, solver._B), needs_grad=requires_grad), + acc_smooth=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + acc_smooth_bw=V(dtype=gs.ti_float, shape=(2, solver.n_dofs_, solver._B), needs_grad=requires_grad), + qf_smooth=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + qf_constraint=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + cdof_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdof_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofvel_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofvel_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofd_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdofd_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + f_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + f_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + ctrl_force=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + ctrl_pos=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + ctrl_vel=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), ctrl_mode=V(dtype=gs.ti_int, shape=shape), hibernated=V(dtype=gs.ti_int, shape=shape), ) @@ -1202,6 +1217,7 @@ def get_dofs_state(solver): @DATA_ORIENTED class StructLinksState(metaclass=BASE_METACLASS): + # *_bw: Cache to avoid overwriting for backward pass cinr_inertial: V_ANNOTATION cinr_pos: V_ANNOTATION cinr_quat: V_ANNOTATION @@ -1214,16 +1230,24 @@ class StructLinksState(metaclass=BASE_METACLASS): cdd_ang: V_ANNOTATION pos: V_ANNOTATION quat: V_ANNOTATION + pos_bw: V_ANNOTATION + quat_bw: V_ANNOTATION i_pos: V_ANNOTATION + i_pos_bw: V_ANNOTATION i_quat: V_ANNOTATION j_pos: V_ANNOTATION j_quat: V_ANNOTATION + j_pos_bw: V_ANNOTATION + j_quat_bw: V_ANNOTATION j_vel: V_ANNOTATION j_ang: V_ANNOTATION cd_ang: V_ANNOTATION cd_vel: V_ANNOTATION + cd_ang_bw: V_ANNOTATION + cd_vel_bw: V_ANNOTATION mass_sum: V_ANNOTATION root_COM: V_ANNOTATION # COM of the kinematic tree + root_COM_bw: V_ANNOTATION mass_shift: V_ANNOTATION i_pos_shift: V_ANNOTATION cacc_ang: V_ANNOTATION @@ -1239,42 +1263,54 @@ class StructLinksState(metaclass=BASE_METACLASS): def get_links_state(solver): + max_n_joints_per_link = solver._static_rigid_sim_config.max_n_joints_per_link shape = (solver.n_links_, solver._B) + shape_bw = (solver.n_links_, max_n_joints_per_link + 1, solver._B) + + requires_grad = solver._requires_grad return StructLinksState( - cinr_inertial=V(dtype=gs.ti_mat3, shape=shape), - cinr_pos=V(dtype=gs.ti_vec3, shape=shape), - cinr_quat=V(dtype=gs.ti_vec4, shape=shape), - cinr_mass=V(dtype=gs.ti_float, shape=shape), - crb_inertial=V(dtype=gs.ti_mat3, shape=shape), - crb_pos=V(dtype=gs.ti_vec3, shape=shape), - crb_quat=V(dtype=gs.ti_vec4, shape=shape), - crb_mass=V(dtype=gs.ti_float, shape=shape), - cdd_vel=V(dtype=gs.ti_vec3, shape=shape), - cdd_ang=V(dtype=gs.ti_vec3, shape=shape), - pos=V(dtype=gs.ti_vec3, shape=shape), - quat=V(dtype=gs.ti_vec4, shape=shape), - i_pos=V(dtype=gs.ti_vec3, shape=shape), - i_quat=V(dtype=gs.ti_vec4, shape=shape), - j_pos=V(dtype=gs.ti_vec3, shape=shape), - j_quat=V(dtype=gs.ti_vec4, shape=shape), - j_vel=V(dtype=gs.ti_vec3, shape=shape), - j_ang=V(dtype=gs.ti_vec3, shape=shape), - cd_ang=V(dtype=gs.ti_vec3, shape=shape), - cd_vel=V(dtype=gs.ti_vec3, shape=shape), - mass_sum=V(dtype=gs.ti_float, shape=shape), - root_COM=V(dtype=gs.ti_vec3, shape=shape), - mass_shift=V(dtype=gs.ti_float, shape=shape), - i_pos_shift=V(dtype=gs.ti_vec3, shape=shape), - cacc_ang=V(dtype=gs.ti_vec3, shape=shape), - cacc_lin=V(dtype=gs.ti_vec3, shape=shape), - cfrc_ang=V(dtype=gs.ti_vec3, shape=shape), - cfrc_vel=V(dtype=gs.ti_vec3, shape=shape), - cfrc_applied_ang=V(dtype=gs.ti_vec3, shape=shape), - cfrc_applied_vel=V(dtype=gs.ti_vec3, shape=shape), - cfrc_coupling_ang=V(dtype=gs.ti_vec3, shape=shape), - cfrc_coupling_vel=V(dtype=gs.ti_vec3, shape=shape), - contact_force=V(dtype=gs.ti_vec3, shape=shape), + cinr_inertial=V(dtype=gs.ti_mat3, shape=shape, needs_grad=requires_grad), + cinr_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cinr_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + cinr_mass=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + crb_inertial=V(dtype=gs.ti_mat3, shape=shape, needs_grad=requires_grad), + crb_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + crb_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + crb_mass=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + cdd_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cdd_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + pos_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + quat_bw=V(dtype=gs.ti_vec4, shape=shape_bw, needs_grad=requires_grad), + i_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + i_pos_bw=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + i_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + j_pos=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + j_quat=V(dtype=gs.ti_vec4, shape=shape, needs_grad=requires_grad), + j_pos_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + j_quat_bw=V(dtype=gs.ti_vec4, shape=shape_bw, needs_grad=requires_grad), + j_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + j_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cd_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cd_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cd_ang_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + cd_vel_bw=V(dtype=gs.ti_vec3, shape=shape_bw, needs_grad=requires_grad), + mass_sum=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + root_COM=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + root_COM_bw=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + mass_shift=V(dtype=gs.ti_float, shape=shape, needs_grad=requires_grad), + i_pos_shift=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cacc_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cacc_lin=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_applied_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_applied_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_coupling_ang=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + cfrc_coupling_vel=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + contact_force=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), hibernated=V(dtype=gs.ti_int, shape=shape), ) @@ -1364,10 +1400,11 @@ class StructJointsState(metaclass=BASE_METACLASS): def get_joints_state(solver): shape = (solver.n_joints_, solver._B) + requires_grad = solver._requires_grad return StructJointsState( - xanchor=V(dtype=gs.ti_vec3, shape=shape), - xaxis=V(dtype=gs.ti_vec3, shape=shape), + xanchor=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), + xaxis=V(dtype=gs.ti_vec3, shape=shape, needs_grad=requires_grad), ) @@ -1719,6 +1756,30 @@ def get_entities_state(solver): ) +# =========================================== RigidAdjointCache =========================================== +@DATA_ORIENTED +class StructRigidAdjointCache(metaclass=BASE_METACLASS): + # This cache stores intermediate values during rigid body simulation to use Taichi's AD. Taichi's AD requires + # us not to overwrite the values that have been read during the forward pass, so we need to store the intemediate + # values in this cache to avoid overwriting them. Specifically, after we compute next frame's qpos, dofs_vel, and + # dofs_acc, we need to store them in this cache because we overwrite the values in the next frame. See how + # [kernel_save_adjoint_cache] is used in [rigid_solver_decomp.py] to store the values in this cache. + qpos: V_ANNOTATION + dofs_vel: V_ANNOTATION + dofs_acc: V_ANNOTATION + + +def get_rigid_adjoint_cache(solver): + substeps_local = solver._sim.substeps_local + requires_grad = solver._requires_grad + + return StructRigidAdjointCache( + qpos=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_qs_, solver._B), needs_grad=requires_grad), + dofs_vel=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), + dofs_acc=V(dtype=gs.ti_float, shape=(substeps_local + 1, solver.n_dofs_, solver._B), needs_grad=requires_grad), + ) + + # =================================== StructRigidSimStaticConfig =================================== @@ -1775,6 +1836,14 @@ def __init__(self, solver): self.entities_info = get_entities_info(solver) self.entities_state = get_entities_state(solver) + if solver._static_rigid_sim_config.requires_grad: + # Data structures required for backward pass + self.dofs_state_adjoint_cache = get_dofs_state(solver) + self.links_state_adjoint_cache = get_links_state(solver) + self.joints_state_adjoint_cache = get_joints_state(solver) + self.geoms_state_adjoint_cache = get_geoms_state(solver) + + self.rigid_adjoint_cache = get_rigid_adjoint_cache(solver) self.errno = V_SCALAR_FROM(dtype=gs.ti_int, value=0) @@ -1810,3 +1879,4 @@ def __init__(self, solver): SDFInfo = StructSDFInfo if gs.use_ndarray else ti.template() ContactIslandState = StructContactIslandState if gs.use_ndarray else ti.template() DiffContactInput = StructDiffContactInput if gs.use_ndarray else ti.template() +RigidAdjointCache = StructRigidAdjointCache if gs.use_ndarray else ti.template() diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index 16a1a1d6a..f387fae55 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -83,8 +83,12 @@ def ti_rotvec_to_R(rotvec, eps): def ti_rotvec_to_quat(rotvec, eps): quat = ti.Vector.zero(gs.ti_float, 4) - theta = rotvec.norm() - if theta > eps: + # We need to use [norm_sqr] instead of [norm] to avoid nan gradients in the backward pass. Even when theta = 0, + # the gradient of [norm] operation is computed and used (note that the gradient becomes NaN when theta = 0). This + # is seemd to be a bug in Taichi autodiff @TODO: change back after the bug is fixed. + thetasq = rotvec.norm_sqr() + if thetasq > (eps**2): + theta = ti.sqrt(thetasq) theta_half = 0.5 * theta c, s = ti.cos(theta_half), ti.sin(theta_half) diff --git a/genesis/utils/path_planning.py b/genesis/utils/path_planning.py index 4cd4a3c4e..40d82198c 100644 --- a/genesis/utils/path_planning.py +++ b/genesis/utils/path_planning.py @@ -429,6 +429,7 @@ def _kernel_rrt_step1( entities_info, rigid_global_info, self._solver._static_rigid_sim_config, + is_backward=False, ) gs.engine.solvers.rigid.rigid_solver_decomp.func_update_geoms( i_b, @@ -439,6 +440,7 @@ def _kernel_rrt_step1( rigid_global_info, self._solver._static_rigid_sim_config, force_update_fixed_geoms=False, + is_backward=False, ) @ti.kernel @@ -797,6 +799,7 @@ def _kernel_rrt_connect_step1( entities_info, rigid_global_info, self._solver._static_rigid_sim_config, + is_backward=False, ) gs.engine.solvers.rigid.rigid_solver_decomp.func_update_geoms( i_b, @@ -807,6 +810,7 @@ def _kernel_rrt_connect_step1( rigid_global_info, self._solver._static_rigid_sim_config, force_update_fixed_geoms=False, + is_backward=False, ) @ti.kernel diff --git a/tests/test_grad.py b/tests/test_grad.py index f532e4602..102bd6702 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -193,6 +193,7 @@ def constraint_solver_resolve(): # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients. kernel_step_1( + f=0, links_state=rigid_solver.links_state, links_info=rigid_solver.links_info, joints_state=rigid_solver.joints_state, @@ -204,8 +205,10 @@ def constraint_solver_resolve(): entities_state=rigid_solver.entities_state, entities_info=rigid_solver.entities_info, rigid_global_info=rigid_solver._rigid_global_info, + rigid_adjoint_cache=rigid_solver._rigid_adjoint_cache, static_rigid_sim_config=rigid_solver._static_rigid_sim_config, contact_island_state=constraint_solver.contact_island.contact_island_state, + is_backward=False, ) constraint_solver.add_equality_constraints() rigid_solver.collider.detection() @@ -264,7 +267,7 @@ def compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force): ### Compute directional derivatives along random directions FD_EPS = 1e-3 - TRIALS = 100 + TRIALS = 200 for dL_dx, x_type in ( (dL_dforce, "force"), @@ -408,3 +411,91 @@ def test_differentiable_push(precision, show_viewer): for v_i in v_list[:-1]: assert (v_i.grad.abs() > gs.EPS).any() assert (v_list[-1].grad.abs() < gs.EPS).all() + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_differentiable_rigid(show_viewer): + dt = 1e-2 + horizon = 100 + substeps = 1 + goal_pos = gs.tensor([0.7, 1.0, 0.05]) + goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9]) + goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True) + + scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=dt, substeps=substeps, requires_grad=True, gravity=(0, 0, -1)), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_contact_island=False, + use_hibernation=False, + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.5, -0.15, 2.42), + camera_lookat=(0.5, 0.5, 0.1), + ), + show_viewer=show_viewer, + ) + + box = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), + ) + if show_viewer: + target = scene.add_entity( + gs.morphs.Box( + pos=goal_pos, + quat=goal_quat, + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.0, 0.9, 0.0, 0.5), + ), + ) + + scene.build() + + num_iter = 200 + lr = 1e-2 + + init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True) + init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) + optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + + for iter in range(num_iter): + scene.reset() + + box.set_pos(init_pos) + box.set_quat(init_quat) + + loss = 0 + for i in range(horizon): + scene.step() + if show_viewer: + target.set_pos(goal_pos) + target.set_quat(goal_quat) + + box_state = box.get_state() + box_pos = box_state.pos + box_quat = box_state.quat + loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum() + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + + assert_allclose(loss, 0.0, atol=1e-2) From d6f79bc3a3925f775800de5716144e6fda21bd90 Mon Sep 17 00:00:00 2001 From: SonSang Date: Tue, 25 Nov 2025 10:30:00 -0500 Subject: [PATCH 02/12] fix minor bug --- genesis/engine/solvers/rigid/rigid_solver_decomp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 4dc5aaedf..423118ac4 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -5642,13 +5642,12 @@ def func_update_geoms( """ NOTE: this only update geom pose, not its verts and else. """ - n_geoms = geoms_info.pos.shape[0] for i_0 in ( ( # Dynamic inner loop for forward pass range(rigid_global_info.n_awake_entities[i_b]) if ti.static(static_rigid_sim_config.use_hibernation) - else range(n_geoms) + else range(geoms_info.pos.shape[0]) ) if ti.static(not is_backward) else ( From 15ad116140b1b9c73e46e7f5d45e2f3f80f3d98d Mon Sep 17 00:00:00 2001 From: SonSang Date: Tue, 25 Nov 2025 14:27:22 -0500 Subject: [PATCH 03/12] fix minor error (gymbal lock issue) in unit test --- tests/test_rigid_physics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_rigid_physics.py b/tests/test_rigid_physics.py index 8f7375e00..7d72dd9e5 100644 --- a/tests/test_rigid_physics.py +++ b/tests/test_rigid_physics.py @@ -1292,9 +1292,11 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol): ): pos_zero = torch.tensor(pos_zero, device=gs.device, dtype=gs.tc_float) euler_zero = torch.deg2rad(torch.tensor(euler_zero, dtype=gs.tc_float)) + quat_zero = gu.xyz_to_quat(euler_zero, rpy=True) assert_allclose(entity.get_pos(), pos_zero, tol=tol) - euler = gu.quat_to_xyz(entity.get_quat(), rpy=True) - assert_allclose(euler, euler_zero, tol=5e-4) + # Use quaternion for comparison to avoid gymbal lock issue in euler angles + quat = entity.get_quat() + assert_allclose(quat, quat_zero, tol=tol) base_aabb = entity.geoms[0].get_AABB() assert base_aabb.shape == ((2, 2, 3) if not entity.geoms[0].is_fixed or batch_fixed_verts else (2, 3)) assert_allclose(base_aabb, base_aabb_init, tol=tol) @@ -1311,14 +1313,12 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol): quat_delta = torch.tile(torch.as_tensor(np.random.rand(4), dtype=gs.tc_float, device=gs.device), (2, 1)) quat_delta /= torch.linalg.norm(quat_delta) entity.set_quat(quat_delta, relative=relative) - euler = gu.quat_to_xyz(entity.get_quat(), rpy=True) - quat_zero = gu.xyz_to_quat(euler_zero, rpy=True) + quat = entity.get_quat() if relative: quat_ref = gu.transform_quat_by_quat(quat_zero, quat_delta) else: quat_ref = quat_delta - euler_ref = gu.quat_to_xyz(quat_ref, rpy=True) - assert_allclose(euler, euler_ref, tol=tol) + assert_allclose(quat, quat_ref, tol=tol) @pytest.mark.required From 4a5fb8286bab1190edae3a36982b2f300f53d88b Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 26 Nov 2025 18:28:09 -0800 Subject: [PATCH 04/12] speedup by removing f and adjoint cache parameter from kernels --- .../solvers/rigid/rigid_solver_decomp.py | 59 +++++++++---------- tests/test_grad.py | 2 - 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 423118ac4..c1aa999de 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -921,8 +921,16 @@ def substep(self, f): self._links_state_cache.clear() + if self._requires_grad and f == 0: + kernel_save_adjoint_cache( + f=f, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + kernel_step_1( - f=f, links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, @@ -934,7 +942,6 @@ def substep(self, f): entities_state=self.entities_state, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, is_backward=False, @@ -950,7 +957,6 @@ def substep(self, f): else: self._func_constraint_force() kernel_step_2( - f=f, dofs_state=self.dofs_state, dofs_info=self.dofs_info, links_info=self.links_info, @@ -963,11 +969,18 @@ def substep(self, f): geoms_state=self.geoms_state, collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, is_backward=False, ) + if self._requires_grad: + kernel_save_adjoint_cache( + f=f + 1, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + rigid_adjoint_cache=self._rigid_adjoint_cache, + static_rigid_sim_config=self._static_rigid_sim_config, + ) def check_errno(self): # Note that errno must be evaluated BEFORE match because otherwise it will be evaluated for each case... @@ -1295,7 +1308,6 @@ def substep_pre_coupling_grad(self, f): gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") kernel_step_2.grad( - f=f, dofs_state=self.dofs_state, dofs_info=self.dofs_info, links_info=self.links_info, @@ -1308,7 +1320,6 @@ def substep_pre_coupling_grad(self, f): geoms_state=self.geoms_state, collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, is_backward=True, @@ -1384,7 +1395,6 @@ def substep_post_coupling(self, f): is_backward=False, ) kernel_step_2( - f=f, dofs_state=self.dofs_state, dofs_info=self.dofs_info, links_info=self.links_info, @@ -1397,7 +1407,6 @@ def substep_post_coupling(self, f): geoms_state=self.geoms_state, collider_state=self.collider._collider_state, rigid_global_info=self._rigid_global_info, - rigid_adjoint_cache=self._rigid_adjoint_cache, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, is_backward=False, @@ -4566,7 +4575,6 @@ def func_update_cartesian_space( @ti.kernel(fastcache=gs.use_fastcache) def kernel_step_1( - f: ti.int32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, joints_state: array_class.JointsState, @@ -4578,21 +4586,10 @@ def kernel_step_1( entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, - rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, is_backward: ti.template(), ): - if ti.static(static_rigid_sim_config.requires_grad): - if f == 0: - func_save_adjoint_cache( - f=f, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - rigid_adjoint_cache=rigid_adjoint_cache, - static_rigid_sim_config=static_rigid_sim_config, - ) - if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): _B = links_state.pos.shape[1] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -4704,7 +4701,6 @@ def func_implicit_damping( @ti.kernel(fastcache=gs.use_fastcache) def kernel_step_2( - f: ti.int32, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, links_info: array_class.LinksInfo, @@ -4717,7 +4713,6 @@ def kernel_step_2( geoms_state: array_class.GeomsState, collider_state: array_class.ColliderState, rigid_global_info: array_class.RigidGlobalInfo, - rigid_adjoint_cache: array_class.RigidAdjointCache, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, is_backward: ti.template(), @@ -4805,15 +4800,6 @@ def kernel_step_2( is_backward=is_backward, ) - if ti.static(static_rigid_sim_config.requires_grad): - func_save_adjoint_cache( - f=f + 1, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - rigid_adjoint_cache=rigid_adjoint_cache, - static_rigid_sim_config=static_rigid_sim_config, - ) - @ti.kernel(fastcache=gs.use_fastcache) def kernel_forward_kinematics_links_geoms( @@ -6807,6 +6793,17 @@ def func_copy_next_to_curr_grad( rigid_global_info.qpos[i_q, i_b] = rigid_adjoint_cache.qpos[f, i_q, i_b] +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_save_adjoint_cache( + f: ti.int32, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + rigid_adjoint_cache: array_class.RigidAdjointCache, + static_rigid_sim_config: ti.template(), +): + func_save_adjoint_cache(f, dofs_state, rigid_global_info, rigid_adjoint_cache, static_rigid_sim_config) + + @ti.func def func_save_adjoint_cache( f: ti.int32, diff --git a/tests/test_grad.py b/tests/test_grad.py index 102bd6702..5dbe31c51 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -193,7 +193,6 @@ def constraint_solver_resolve(): # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients. kernel_step_1( - f=0, links_state=rigid_solver.links_state, links_info=rigid_solver.links_info, joints_state=rigid_solver.joints_state, @@ -205,7 +204,6 @@ def constraint_solver_resolve(): entities_state=rigid_solver.entities_state, entities_info=rigid_solver.entities_info, rigid_global_info=rigid_solver._rigid_global_info, - rigid_adjoint_cache=rigid_solver._rigid_adjoint_cache, static_rigid_sim_config=rigid_solver._static_rigid_sim_config, contact_island_state=constraint_solver.contact_island.contact_island_state, is_backward=False, From 4c4843bce160dda90d1112539fd079e18f1b0c05 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 26 Nov 2025 18:57:59 -0800 Subject: [PATCH 05/12] speedup by removing is_backward from kernels --- .../solvers/rigid/constraint_solver_decomp.py | 1 - .../solvers/rigid/rigid_solver_decomp.py | 375 ++++++++---------- genesis/utils/array_class.py | 1 + 3 files changed, 176 insertions(+), 201 deletions(-) diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index a4b0a76e3..ec7cb9ffd 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -1994,7 +1994,6 @@ def func_update_gradient( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) elif ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index c1aa999de..60d9263bd 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -268,6 +268,7 @@ def build(self): sparse_solve=self._options.sparse_solve, integrator=self._integrator, solver_type=self._options.constraint_solver, + is_backward=False, ) else: self._static_rigid_sim_config = array_class.StructRigidSimStaticConfig( @@ -285,6 +286,7 @@ def build(self): sparse_solve=False, integrator=gs.integrator.approximate_implicitfast, solver_type=gs.constraint_solver.CG, + is_backward=False, ) if self._static_rigid_sim_config.requires_grad: @@ -410,7 +412,6 @@ def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, u rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, decompose=True, - is_backward=False, ) # Define some proxies for convenience @@ -944,7 +945,6 @@ def substep(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=False, ) if isinstance(self.sim.coupler, SAPCoupler): @@ -952,7 +952,6 @@ def substep(self, f): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) else: self._func_constraint_force() @@ -971,7 +970,6 @@ def substep(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=False, ) if self._requires_grad: kernel_save_adjoint_cache( @@ -1048,7 +1046,6 @@ def _func_forward_dynamics(self): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=False, ) def _func_update_acc(self): @@ -1059,7 +1056,6 @@ def _func_update_acc(self): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def _func_forward_kinematics_entity(self, i_e, envs_idx): @@ -1075,7 +1071,6 @@ def _func_forward_kinematics_entity(self, i_e, envs_idx): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): @@ -1092,7 +1087,7 @@ def _func_integrate_dq_entity(self, dq, i_e, i_b, respect_joint_limit): static_rigid_sim_config=self._static_rigid_sim_config, ) - def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False, is_backward=False): + def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False): kernel_update_geoms( envs_idx, entities_info=self.entities_info, @@ -1102,7 +1097,6 @@ def _func_update_geoms(self, envs_idx, *, force_update_fixed_geoms=False, is_bac rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=force_update_fixed_geoms, - is_backward=is_backward, ) def _process_dim(self, tensor, envs_idx=None): @@ -1244,6 +1238,9 @@ def substep_pre_coupling(self, f): self.substep(f) def substep_pre_coupling_grad(self, f): + # Change to backward mode + self._static_rigid_sim_config.is_backward = True + # Run forward substep again to restore this step's information, this is needed because we do not store info # of every substep. kernel_prepare_backward_substep( @@ -1282,7 +1279,6 @@ def substep_pre_coupling_grad(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=True, ) is_grad_valid = kernel_begin_backward_substep( @@ -1322,7 +1318,6 @@ def substep_pre_coupling_grad(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=True, ) # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, @@ -1339,7 +1334,6 @@ def substep_pre_coupling_grad(self, f): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=True, ) kernel_copy_acc( f=f, @@ -1360,7 +1354,6 @@ def substep_pre_coupling_grad(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=True, ) # If it was the very first substep, we need to backpropagate through the initial update of the cartesian space @@ -1378,9 +1371,11 @@ def substep_pre_coupling_grad(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=True, ) + # Change back to forward mode + self._static_rigid_sim_config.is_backward = False + def substep_post_coupling(self, f): from genesis.engine.couplers import SAPCoupler, IPCCoupler @@ -1392,7 +1387,6 @@ def substep_post_coupling(self, f): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) kernel_step_2( dofs_state=self.dofs_state, @@ -1409,7 +1403,6 @@ def substep_post_coupling(self, f): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=False, ) elif isinstance(self.sim.coupler, IPCCoupler): # For IPCCoupler, perform full rigid body computation in post-coupling phase @@ -1560,7 +1553,6 @@ def set_state(self, f, state, envs_idx=None): entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) self._errno[None] = 0 @@ -1613,7 +1605,6 @@ def load_ckpt(self, ckpt_name): rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=False, ) for entity in self._entities: @@ -1879,7 +1870,6 @@ def set_base_links_pos( entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def set_base_links_pos_grad(self, links_idx, envs_idx, relative, unsafe, pos_grad): @@ -1960,7 +1950,6 @@ def set_base_links_quat( entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def set_base_links_quat_grad(self, links_idx, envs_idx, relative, unsafe, quat_grad): @@ -2068,7 +2057,6 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def set_global_sol_params(self, sol_params, *, unsafe=False): @@ -2257,7 +2245,6 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def set_dofs_velocity_grad(self, dofs_idx, envs_idx, unsafe, velocity_grad): @@ -2309,7 +2296,6 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forw entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=False, ) def control_dofs_force(self, force, dofs_idx=None, envs_idx=None, *, unsafe=False): @@ -2951,7 +2937,6 @@ def update_qacc_from_qvel_delta( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): n_dofs = dofs_state.ctrl_mode.shape[0] _B = dofs_state.ctrl_mode.shape[1] @@ -2964,7 +2949,7 @@ def update_qacc_from_qvel_delta( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) @@ -2989,7 +2974,6 @@ def update_qvel( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): _B = dofs_state.vel.shape[1] n_dofs = dofs_state.vel.shape[0] @@ -3003,7 +2987,7 @@ def update_qvel( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) @@ -3034,7 +3018,6 @@ def kernel_compute_mass_matrix( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), decompose: ti.template(), - is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=False, @@ -3045,7 +3028,6 @@ def kernel_compute_mass_matrix( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) if decompose: func_factor_mass( @@ -3055,7 +3037,6 @@ def kernel_compute_mass_matrix( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -3622,7 +3603,6 @@ def kernel_forward_dynamics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, - is_backward: ti.template(), ): func_forward_dynamics( links_state=links_state, @@ -3636,7 +3616,6 @@ def kernel_forward_dynamics( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, - is_backward=is_backward, ) @@ -3648,7 +3627,6 @@ def kernel_update_acc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): func_update_acc( update_cacc=True, @@ -3658,7 +3636,6 @@ def kernel_update_acc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -3683,7 +3660,6 @@ def func_compute_mass_matrix( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): # crb initialize ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -3699,7 +3675,7 @@ def func_compute_mass_matrix( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -3735,7 +3711,7 @@ def func_compute_mass_matrix( if ti.static(static_rigid_sim_config.use_hibernation) else ti.static(range(1)) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -3754,7 +3730,7 @@ def func_compute_mass_matrix( for i in ( range(entities_info.n_links[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): if i < entities_info.n_links[i_e]: @@ -3782,7 +3758,7 @@ def func_compute_mass_matrix( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -3802,10 +3778,12 @@ def func_compute_mass_matrix( for i_d_ in ( range(links_info.dof_start[I_l], links_info.dof_end[I_l]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): - i_d = i_d_ if ti.static(not is_backward) else links_info.dof_start[I_l] + i_d_ + i_d = ( + i_d_ if ti.static(not static_rigid_sim_config.is_backward) else links_info.dof_start[I_l] + i_d_ + ) if i_d < links_info.dof_end[I_l]: dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( @@ -3829,7 +3807,7 @@ def func_compute_mass_matrix( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -3854,7 +3832,7 @@ def func_compute_mass_matrix( (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), ) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static( @@ -3865,8 +3843,16 @@ def func_compute_mass_matrix( ) ) ): - i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ - j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else entities_info.dof_start[i_e] + i_d_ + ) + j_d = ( + j_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else entities_info.dof_start[i_e] + j_d_ + ) if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e]: rigid_global_info.mass_mat[i_d, j_d, i_b] = ( @@ -3874,7 +3860,7 @@ def func_compute_mass_matrix( + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) ) * rigid_global_info.mass_parent_mask[i_d, j_d] - if ti.static(not is_backward): + if ti.static(not static_rigid_sim_config.is_backward): for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): for j_d in range(i_d + 1, entities_info.dof_end[i_e]): rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] @@ -3919,12 +3905,11 @@ def func_factor_mass( dofs_info: array_class.DofsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): """ Compute Cholesky decomposition (L^T @ D @ L) of mass matrix. """ - if ti.static(not is_backward): + if ti.static(not static_rigid_sim_config.is_backward): _B = dofs_state.ctrl_mode.shape[1] n_entities = entities_info.n_links.shape[0] @@ -3939,7 +3924,7 @@ def func_factor_mass( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -3963,18 +3948,26 @@ def func_factor_mass( for i_d_ in ( range(entity_dof_start, entity_dof_end) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else entities_info.dof_start[i_e] + i_d_ + ) if i_d < entity_dof_end: for j_d_ in ( range(entity_dof_start, i_d + 1) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - j_d = j_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + j_d_ + j_d = ( + j_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else entities_info.dof_start[i_e] + j_d_ + ) if j_d < i_d + 1: rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[ @@ -3996,7 +3989,7 @@ def func_factor_mass( for i_d_ in ( range(n_dofs) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): if i_d_ < n_dofs: @@ -4007,7 +4000,7 @@ def func_factor_mass( for j_d_ in ( range(i_d - entity_dof_start) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): if j_d_ < i_d - entity_dof_start: @@ -4019,12 +4012,12 @@ def func_factor_mass( for k_d_ in ( range(entity_dof_start, j_d + 1) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): k_d = ( k_d_ - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else entities_info.dof_start[i_e] + k_d_ ) if k_d < j_d + 1: @@ -4052,7 +4045,7 @@ def func_factor_mass( for i_d0 in ( range(n_dofs) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): if i_d0 < n_dofs: @@ -4060,10 +4053,14 @@ def func_factor_mass( i_pr = (entity_dof_start + entity_dof_end - 1) - i_d for j_d_ in ( range(entity_dof_start, i_d + 1) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + j_d = ( + j_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (j_d_ + entities_info.dof_start[i_e]) + ) j_pr = (entity_dof_start + entity_dof_end - 1) - j_d if j_d < i_d + 1: rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] = rigid_global_info.mass_mat[ @@ -4091,12 +4088,12 @@ def func_factor_mass( # https://en.wikipedia.org/wiki/Cholesky_decomposition for p_i0 in ( range(n_dofs) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): for p_j0 in ( range(p_i0 + 1) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): if p_i0 < n_dofs and p_j0 < n_dofs and p_j0 <= p_i0: @@ -4107,7 +4104,7 @@ def func_factor_mass( sum = gs.ti_float(0.0) for p_k0 in ( range(p_j0) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): # k_pr < j_pr @@ -4129,12 +4126,12 @@ def func_factor_mass( for i_d0 in ( range(n_dofs) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): for i_d1 in ( range(i_d0 + 1) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): if i_d0 < n_dofs and i_d1 < n_dofs and i_d1 <= i_d0: @@ -4163,7 +4160,6 @@ def func_solve_mass_batched( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): # This loop is considered an inner loop ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) @@ -4174,7 +4170,7 @@ def func_solve_mass_batched( if ti.static(static_rigid_sim_config.use_hibernation) else range(entities_info.n_links.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -4203,26 +4199,30 @@ def func_solve_mass_batched( # Step 1: Solve w st. L^T @ w = y for i_d_ in ( range(n_dofs) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): if i_d_ < n_dofs: i_d = entity_dof_end - i_d_ - 1 - if ti.static(is_backward): + if ti.static(static_rigid_sim_config.is_backward): out_bw[0, i_d, i_b] = vec[i_d, i_b] else: out[i_d, i_b] = vec[i_d, i_b] for j_d_ in ( range(i_d + 1, entity_dof_end) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + j_d = ( + j_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (j_d_ + entities_info.dof_start[i_e]) + ) if j_d >= i_d + 1 and j_d < entity_dof_end: # Since we read out[j_d, i_b], and j_d > i_d, which means that out[j_d, i_b] is already # finalized at this point, we don't need to care about AD mutation rule. - if ti.static(is_backward): + if ti.static(static_rigid_sim_config.is_backward): out_bw[0, i_d, i_b] += -( rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out_bw[0, j_d, i_b] ) @@ -4232,12 +4232,16 @@ def func_solve_mass_batched( # Step 2: z = D^{-1} w for i_d_ in ( range(entity_dof_start, entity_dof_end) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + entities_info.dof_start[i_e]) + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + entities_info.dof_start[i_e]) + ) if i_d < entity_dof_end: - if ti.static(is_backward): + if ti.static(static_rigid_sim_config.is_backward): out_bw[1, i_d, i_b] = out_bw[0, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] else: out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] @@ -4245,21 +4249,29 @@ def func_solve_mass_batched( # Step 3: Solve x st. L @ x = z for i_d_ in ( range(entity_dof_start, entity_dof_end) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + entities_info.dof_start[i_e]) + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + entities_info.dof_start[i_e]) + ) if i_d < entity_dof_end: curr_out = out[i_d, i_b] - if ti.static(is_backward): + if ti.static(static_rigid_sim_config.is_backward): curr_out = out_bw[1, i_d, i_b] for j_d_ in ( range(entity_dof_start, i_d) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - j_d = j_d_ if ti.static(not is_backward) else (j_d_ + entities_info.dof_start[i_e]) + j_d = ( + j_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (j_d_ + entities_info.dof_start[i_e]) + ) if j_d < i_d: curr_out += -(rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]) @@ -4274,7 +4286,6 @@ def func_solve_mass( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): # This loop must be the outermost loop to be differentiable ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) @@ -4287,7 +4298,6 @@ def func_solve_mass( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -4306,7 +4316,6 @@ def func_forward_dynamics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, - is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), @@ -4317,7 +4326,6 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_factor_mass( implicit_damping=False, @@ -4326,7 +4334,6 @@ def func_forward_dynamics( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_torque_and_passive_force( entities_state=entities_state, @@ -4340,7 +4347,6 @@ def func_forward_dynamics( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, - is_backward=is_backward, ) func_update_acc( update_cacc=False, @@ -4350,7 +4356,6 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_update_force( links_state=links_state, @@ -4358,7 +4363,6 @@ def func_forward_dynamics( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) # self._func_actuation() func_bias_force( @@ -4367,14 +4371,12 @@ def func_forward_dynamics( links_info=links_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_compute_qacc( dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -4391,7 +4393,6 @@ def kernel_forward_dynamics_without_qacc( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, - is_backward: ti.template(), ): func_compute_mass_matrix( implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), @@ -4402,7 +4403,6 @@ def kernel_forward_dynamics_without_qacc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_factor_mass( implicit_damping=False, @@ -4411,7 +4411,6 @@ def kernel_forward_dynamics_without_qacc( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_torque_and_passive_force( entities_state=entities_state, @@ -4425,7 +4424,6 @@ def kernel_forward_dynamics_without_qacc( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, - is_backward=is_backward, ) func_update_acc( update_cacc=False, @@ -4435,7 +4433,6 @@ def kernel_forward_dynamics_without_qacc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_update_force( links_state=links_state, @@ -4443,7 +4440,6 @@ def kernel_forward_dynamics_without_qacc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) # self._func_actuation() func_bias_force( @@ -4452,7 +4448,6 @@ def kernel_forward_dynamics_without_qacc( links_info=links_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -4483,7 +4478,6 @@ def kernel_update_cartesian_space( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), - is_backward: ti.template(), ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(links_state.pos.shape[1]): @@ -4501,7 +4495,6 @@ def kernel_update_cartesian_space( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=force_update_fixed_geoms, - is_backward=is_backward, ) @@ -4520,7 +4513,6 @@ def func_update_cartesian_space( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), - is_backward: ti.template(), ): func_forward_kinematics( i_b, @@ -4533,7 +4525,6 @@ def func_update_cartesian_space( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_COM_links( i_b, @@ -4546,7 +4537,6 @@ def func_update_cartesian_space( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_forward_velocity( i_b, @@ -4557,7 +4547,6 @@ def func_update_cartesian_space( dofs_state=dofs_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_update_geoms( @@ -4569,7 +4558,6 @@ def func_update_cartesian_space( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=force_update_fixed_geoms, - is_backward=is_backward, ) @@ -4588,7 +4576,6 @@ def kernel_step_1( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, - is_backward: ti.template(), ): if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): _B = links_state.pos.shape[1] @@ -4608,7 +4595,6 @@ def kernel_step_1( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=is_backward, ) func_forward_dynamics( @@ -4623,7 +4609,6 @@ def kernel_step_1( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, contact_island_state=contact_island_state, - is_backward=is_backward, ) @@ -4634,7 +4619,6 @@ def func_implicit_damping( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): EPS = rigid_global_info.EPS[None] @@ -4656,10 +4640,12 @@ def func_implicit_damping( entity_dof_end = entities_info.dof_end[i_e] for i_d_ in ( range(entity_dof_start, entity_dof_end) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - i_d = i_d_ if ti.static(not is_backward) else entities_info.dof_start[i_e] + i_d_ + i_d = ( + i_d_ if ti.static(not static_rigid_sim_config.is_backward) else entities_info.dof_start[i_e] + i_d_ + ) if i_d < entity_dof_end: I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d if dofs_info.damping[I_d] > EPS: @@ -4678,7 +4664,6 @@ def func_implicit_damping( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_solve_mass( vec=dofs_state.force, @@ -4687,7 +4672,6 @@ def func_implicit_damping( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) # Disable pre-computed factorization mask right away @@ -4715,7 +4699,6 @@ def kernel_step_2( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, - is_backward: ti.template(), ): # Position, Velocity and Acceleration data must be consistent when computing links acceleration, otherwise it # would not corresponds to anyting physical. There is no other way than doing this right before integration, @@ -4730,7 +4713,6 @@ def kernel_step_2( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) if ti.static(static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast): @@ -4740,7 +4722,6 @@ def kernel_step_2( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_integrate( @@ -4749,7 +4730,6 @@ def kernel_step_2( joints_info=joints_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) if ti.static(static_rigid_sim_config.use_hibernation): @@ -4772,7 +4752,7 @@ def kernel_step_2( static_rigid_sim_config=static_rigid_sim_config, ) - if ti.static(not is_backward): + if ti.static(not static_rigid_sim_config.is_backward): func_copy_next_to_curr( dofs_state=dofs_state, rigid_global_info=rigid_global_info, @@ -4797,7 +4777,6 @@ def kernel_step_2( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=is_backward, ) @@ -4815,7 +4794,6 @@ def kernel_forward_kinematics_links_geoms( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -4834,7 +4812,6 @@ def kernel_forward_kinematics_links_geoms( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=True, - is_backward=is_backward, ) @@ -4850,7 +4827,6 @@ def func_COM_links( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l_ in ( @@ -4860,7 +4836,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -4888,7 +4864,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -4929,7 +4905,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -4959,7 +4935,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -4988,7 +4964,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -5032,7 +5008,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -5076,13 +5052,13 @@ def func_COM_links( for i_j_ in ( range(n_joints) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) ): i_j = i_j_ + links_info.joint_start[I_l] - curr_i_j = 0 if ti.static(not is_backward) else i_j_ - next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 + curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 if i_j < links_info.joint_end[I_l]: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j @@ -5097,7 +5073,7 @@ def func_COM_links( links_state.j_quat_bw[i_l, curr_i_j, i_b], ) - i_j_ = 0 if ti.static(not is_backward) else n_joints + i_j_ = 0 if ti.static(not static_rigid_sim_config.is_backward) else n_joints links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, i_j_, i_b] links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, i_j_, i_b] @@ -5109,7 +5085,7 @@ def func_COM_links( if ti.static(static_rigid_sim_config.use_hibernation) else range(links_info.root_idx.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -5130,10 +5106,14 @@ def func_COM_links( if links_info.n_dofs[I_l] > 0: for i_j_ in ( range(links_info.joint_start[I_l], links_info.joint_end[I_l]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) ): - i_j = i_j_ if ti.static(not is_backward) else (i_j_ + links_info.joint_start[I_l]) + i_j = ( + i_j_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_j_ + links_info.joint_start[I_l]) + ) if i_j < links_info.joint_end[I_l]: offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] @@ -5167,10 +5147,10 @@ def func_COM_links( for i_d_ in ( range(dof_start, joints_info.dof_end[I_j]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) if i_d < joints_info.dof_end[I_j]: dofs_state.cdofvel_ang[i_d, i_b] = ( dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] @@ -5192,7 +5172,6 @@ def func_forward_kinematics( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): for i_e_ in ( ( @@ -5200,7 +5179,7 @@ def func_forward_kinematics( if ti.static(static_rigid_sim_config.use_hibernation) else range(entities_info.n_links.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( ti.static(range(static_rigid_sim_config.max_n_awake_entities)) if ti.static(static_rigid_sim_config.use_hibernation) @@ -5230,7 +5209,6 @@ def func_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, - is_backward, ) @@ -5244,7 +5222,6 @@ def func_forward_velocity( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): n_entities = entities_info.n_links.shape[0] for i_e_ in ( @@ -5254,7 +5231,7 @@ def func_forward_velocity( if ti.static(static_rigid_sim_config.use_hibernation) else range(n_entities) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -5275,7 +5252,6 @@ def func_forward_velocity( dofs_state=dofs_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -5292,7 +5268,6 @@ def kernel_forward_kinematics_entity( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -5309,7 +5284,6 @@ def kernel_forward_kinematics_entity( entities_info, rigid_global_info, static_rigid_sim_config, - is_backward, ) @@ -5326,16 +5300,15 @@ def func_forward_kinematics_entity( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): # Becomes static loop in backward pass, because we assume this loop is an inner loop for i_l_ in ( range(entities_info.link_start[i_e], entities_info.link_end[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): EPS = rigid_global_info.EPS[None] - i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) + i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l @@ -5352,13 +5325,13 @@ def func_forward_kinematics_entity( for i_j_ in ( range(n_joints) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) ): i_j = i_j_ + links_info.joint_start[I_l] - curr_i_j = 0 if ti.static(not is_backward) else i_j_ - next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 + curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 if i_j < links_info.joint_end[I_l]: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j @@ -5459,7 +5432,7 @@ def func_forward_kinematics_entity( ) # Skip link pose update for fixed root links to let users manually overwrite them - i_j_ = 0 if ti.static(not is_backward) else n_joints + i_j_ = 0 if ti.static(not static_rigid_sim_config.is_backward) else n_joints if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): links_state.pos[i_l, i_b] = links_state.pos_bw[i_l, i_j_, i_b] links_state.quat[i_l, i_b] = links_state.quat_bw[i_l, i_j_, i_b] @@ -5476,14 +5449,13 @@ def func_forward_velocity_entity( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): for i_l_ in ( range(entities_info.link_start[i_e], entities_info.link_end[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) + i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l @@ -5498,7 +5470,7 @@ def func_forward_velocity_entity( for i_j_ in ( range(n_joints) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) ): i_j = i_j_ + links_info.joint_start[I_l] @@ -5509,8 +5481,8 @@ def func_forward_velocity_entity( q_start = joints_info.q_start[I_j] dof_start = joints_info.dof_start[I_j] - curr_i_j = 0 if ti.static(not is_backward) else i_j_ - next_i_j = 0 if ti.static(not is_backward) else i_j_ + 1 + curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 if joint_type == gs.JOINT_TYPE.FREE: for i_3 in ti.static(range(3)): @@ -5551,10 +5523,10 @@ def func_forward_velocity_entity( else: for i_d_ in ( range(dof_start, joints_info.dof_end[I_j]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) if i_d < joints_info.dof_end[I_j]: dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( links_state.cd_ang_bw[i_l, curr_i_j, i_b], @@ -5568,10 +5540,10 @@ def func_forward_velocity_entity( for i_d_ in ( range(dof_start, joints_info.dof_end[I_j]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + dof_start) + i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) if i_d < joints_info.dof_end[I_j]: links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] @@ -5580,7 +5552,7 @@ def func_forward_velocity_entity( dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] ) - i_j_ = 0 if ti.static(not is_backward) else n_joints + i_j_ = 0 if ti.static(not static_rigid_sim_config.is_backward) else n_joints links_state.cd_vel[i_l, i_b] = links_state.cd_vel_bw[i_l, i_j_, i_b] links_state.cd_ang[i_l, i_b] = links_state.cd_ang_bw[i_l, i_j_, i_b] @@ -5595,7 +5567,6 @@ def kernel_update_geoms( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), - is_backward: ti.template(), ): for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -5609,7 +5580,6 @@ def kernel_update_geoms( rigid_global_info, static_rigid_sim_config, force_update_fixed_geoms, - is_backward, ) @@ -5623,7 +5593,6 @@ def func_update_geoms( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), force_update_fixed_geoms: ti.template(), - is_backward: ti.template(), ): """ NOTE: this only update geom pose, not its verts and else. @@ -5635,7 +5604,7 @@ def func_update_geoms( if ti.static(static_rigid_sim_config.use_hibernation) else range(geoms_info.pos.shape[0]) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -5653,7 +5622,7 @@ def func_update_geoms( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_geoms_per_entity)) @@ -6054,7 +6023,6 @@ def func_torque_and_passive_force( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), contact_island_state: array_class.ContactIslandState, - is_backward: ti.template(), ): # compute force based on each dof's ctrl mode ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -6064,10 +6032,10 @@ def func_torque_and_passive_force( for i_l_ in ( range(entities_info.link_start[i_e], entities_info.link_end[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) + i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l @@ -6078,10 +6046,14 @@ def func_torque_and_passive_force( for i_d_ in ( range(links_info.dof_start[I_l], links_info.dof_end[I_l]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + links_info.dof_start[I_l]) + ) if i_d < links_info.dof_end[I_l]: I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d @@ -6176,7 +6148,7 @@ def func_torque_and_passive_force( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) @@ -6207,7 +6179,7 @@ def func_torque_and_passive_force( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -6237,7 +6209,7 @@ def func_torque_and_passive_force( for j_d in ( range(dof_end - dof_start) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): if j_d < dof_end: @@ -6260,7 +6232,6 @@ def func_update_acc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): # Assume this is the outermost loop ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -6276,7 +6247,7 @@ def func_update_acc( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -6295,10 +6266,14 @@ def func_update_acc( for i_l_ in ( range(entities_info.link_start[i_e], entities_info.link_end[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - i_l = i_l_ if ti.static(not is_backward) else (i_l_ + entities_info.link_start[i_e]) + i_l = ( + i_l_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_l_ + entities_info.link_start[i_e]) + ) if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l @@ -6321,10 +6296,14 @@ def func_update_acc( for i_d_ in ( range(links_info.dof_start[I_l], links_info.dof_end[I_l]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + links_info.dof_start[I_l]) + ) if i_d < links_info.dof_end[I_l]: # cacc = cacc_parent + cdofdot * qvel + cdof * qacc @@ -6348,7 +6327,6 @@ def func_update_force( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_0, i_b in ( @@ -6363,7 +6341,7 @@ def func_update_force( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -6418,7 +6396,7 @@ def func_update_force( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -6437,7 +6415,7 @@ def func_update_force( for i_l_ in ( range(entities_info.n_links[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): if i_l_ < entities_info.n_links[i_e]: @@ -6486,7 +6464,6 @@ def func_bias_force( links_info: array_class.LinksInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_0, i_b in ( @@ -6501,7 +6478,7 @@ def func_bias_force( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -6521,10 +6498,14 @@ def func_bias_force( for i_d_ in ( range(links_info.dof_start[I_l], links_info.dof_end[I_l]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): - i_d = i_d_ if ti.static(not is_backward) else (i_d_ + links_info.dof_start[I_l]) + i_d = ( + i_d_ + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d_ + links_info.dof_start[I_l]) + ) if i_d < links_info.dof_end[I_l]: dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( links_state.cfrc_ang[i_l, i_b] @@ -6546,14 +6527,12 @@ def kernel_compute_qacc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): func_compute_qacc( dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) @@ -6563,7 +6542,6 @@ def func_compute_qacc( entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): func_solve_mass( vec=dofs_state.force, @@ -6572,7 +6550,6 @@ def func_compute_qacc( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) # Assume this is the outermost loop @@ -6589,7 +6566,7 @@ def func_compute_qacc( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_entities)) @@ -6608,7 +6585,7 @@ def func_compute_qacc( for i_d1_ in ( range(entities_info.n_dofs[i_e]) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): i_d1 = entities_info.dof_start[i_e] + i_d1_ @@ -6623,7 +6600,6 @@ def func_integrate( joints_info: array_class.JointsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), - is_backward: ti.template(), ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_0, i_b in ( @@ -6638,7 +6614,7 @@ def func_integrate( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_dofs)) @@ -6670,7 +6646,7 @@ def func_integrate( if ti.static(static_rigid_sim_config.use_hibernation) else range(1) ) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else ( # Static inner loop for backward pass ti.static(range(static_rigid_sim_config.max_n_awake_links)) @@ -6742,7 +6718,7 @@ def func_integrate( else: for j_ in ( (range(q_end - q_start)) - if ti.static(not is_backward) + if ti.static(not static_rigid_sim_config.is_backward) else (ti.static(range(static_rigid_sim_config.max_n_qs_per_link))) ): j = q_start + j_ @@ -6896,7 +6872,6 @@ def kernel_prepare_backward_substep( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=False, ) # Save results of [update_cartesian_space] to adjoint cache func_copy_cartesian_space( diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 76922313f..534275f3c 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -1799,6 +1799,7 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta): sparse_solve: bool integrator: int solver_type: int + is_backward: bool # =========================================== DataManager =========================================== From 30cb2c593d263a8a4c0300b586de126b11c1b56e Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 26 Nov 2025 19:20:06 -0800 Subject: [PATCH 06/12] completely removed is_backward param --- genesis/engine/entities/rigid_entity/rigid_entity.py | 2 -- genesis/engine/solvers/rigid/constraint_noslip.py | 3 --- genesis/utils/path_planning.py | 4 ---- tests/test_grad.py | 1 - 4 files changed, 10 deletions(-) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index fd9ea8a48..586c61252 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1477,7 +1477,6 @@ def _kernel_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, - is_backward=False, ) ti.loop_config(serialize=ti.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) @@ -1507,7 +1506,6 @@ def _kernel_forward_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, - is_backward=False, ) # ------------------------------------------------------------------------------------ diff --git a/genesis/engine/solvers/rigid/constraint_noslip.py b/genesis/engine/solvers/rigid/constraint_noslip.py index 8e02a545c..b1dc3f153 100644 --- a/genesis/engine/solvers/rigid/constraint_noslip.py +++ b/genesis/engine/solvers/rigid/constraint_noslip.py @@ -40,7 +40,6 @@ def kernel_build_efc_AR_b( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) # AR[r, c] = J[c, :] * tmp @@ -198,7 +197,6 @@ def kernel_dual_finish( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) for i_d in range(n_dofs): @@ -292,7 +290,6 @@ def compute_A_diag( entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) # Ai = Ji * tmp diff --git a/genesis/utils/path_planning.py b/genesis/utils/path_planning.py index 40d82198c..4cd4a3c4e 100644 --- a/genesis/utils/path_planning.py +++ b/genesis/utils/path_planning.py @@ -429,7 +429,6 @@ def _kernel_rrt_step1( entities_info, rigid_global_info, self._solver._static_rigid_sim_config, - is_backward=False, ) gs.engine.solvers.rigid.rigid_solver_decomp.func_update_geoms( i_b, @@ -440,7 +439,6 @@ def _kernel_rrt_step1( rigid_global_info, self._solver._static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=False, ) @ti.kernel @@ -799,7 +797,6 @@ def _kernel_rrt_connect_step1( entities_info, rigid_global_info, self._solver._static_rigid_sim_config, - is_backward=False, ) gs.engine.solvers.rigid.rigid_solver_decomp.func_update_geoms( i_b, @@ -810,7 +807,6 @@ def _kernel_rrt_connect_step1( rigid_global_info, self._solver._static_rigid_sim_config, force_update_fixed_geoms=False, - is_backward=False, ) @ti.kernel diff --git a/tests/test_grad.py b/tests/test_grad.py index 5dbe31c51..2349f1129 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -206,7 +206,6 @@ def constraint_solver_resolve(): rigid_global_info=rigid_solver._rigid_global_info, static_rigid_sim_config=rigid_solver._static_rigid_sim_config, contact_island_state=constraint_solver.contact_island.contact_island_state, - is_backward=False, ) constraint_solver.add_equality_constraints() rigid_solver.collider.detection() From e4227799706c0efd807372cac73356b6326c0241 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 26 Nov 2025 19:53:58 -0800 Subject: [PATCH 07/12] removed is_backward parameter --- genesis/engine/entities/rigid_entity/rigid_entity.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 586c61252..b4ee4d358 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -3367,7 +3367,6 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, - False, ) # compute error solved = True @@ -3496,7 +3495,6 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, - False, ) solved = True for i_ee in range(n_links): @@ -3596,5 +3594,4 @@ def kernel_rigid_entity_inverse_kinematics( entities_info, rigid_global_info, static_rigid_sim_config, - False, ) From 7ddc810ec0c3b86c4211ec0f79d54d04d13fe6eb Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Wed, 26 Nov 2025 23:24:20 -0800 Subject: [PATCH 08/12] do not store target values when not differentiable --- genesis/engine/entities/rigid_entity/rigid_entity.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index d8ed1ca42..dee0dcaf6 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -103,7 +103,7 @@ def __init__( self._tgt = dict() self._tgt_buffer = list() self._ckpt = dict() - self._update_tgt_while_set = True + self._update_tgt_while_set = self._solver._requires_grad def _update_tgt(self, key, value): # Set [self._tgt] value while keeping the insertion order between keys. When a new key is inserted or an existing @@ -1653,6 +1653,7 @@ def process_input(self, in_backward=False): else: self._tgt_buffer.append(self._tgt.copy()) + update_tgt_while_set = self._update_tgt_while_set # Apply targets in the order of insertion for key in self._tgt.keys(): data_kwargs = self._tgt[key] @@ -1677,7 +1678,7 @@ def process_input(self, in_backward=False): gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") self._tgt = dict() - self._update_tgt_while_set = True + self._update_tgt_while_set = update_tgt_while_set def process_input_grad(self): index = self._sim.cur_step_local - self._sim._steps_local From fc8a20669022fed449b46041163d5f2ceed47a21 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Thu, 27 Nov 2025 14:31:15 -0800 Subject: [PATCH 09/12] special handling when there is only non-differentiable rigid solver --- genesis/engine/simulator.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index 9b27ab2b2..67fd4d843 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -266,13 +266,18 @@ def f_global_to_s_global(self, f_global): # ------------------------------------------------------------------------------------ def step(self, in_backward=False): - self.process_input(in_backward=in_backward) - for _ in range(self._substeps): - self.substep(self.cur_substep_local) + if self._rigid_only and (not self._requires_grad): # "Only Advance!" --Thomas Wade :P + for _ in range(self._substeps): + self.rigid_solver.substep(self.cur_substep_local) + self._cur_substep_global += 1 + else: + self.process_input(in_backward=in_backward) + for _ in range(self._substeps): + self.substep(self.cur_substep_local) - self._cur_substep_global += 1 - if self.cur_substep_local == 0 and not in_backward: - self.save_ckpt() + self._cur_substep_global += 1 + if self.cur_substep_local == 0 and not in_backward: + self.save_ckpt() if self.rigid_solver.is_active: self.rigid_solver.clear_external_force() From 47cfd3dcdf6040a3220cd982bbb7443f9953c3c5 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Thu, 27 Nov 2025 18:47:14 -0800 Subject: [PATCH 10/12] remove unnecessary atomic add in forward pass for speedup --- .../solvers/rigid/rigid_solver_decomp.py | 152 ++++++++++++++---- 1 file changed, 117 insertions(+), 35 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index d68d8b259..047ce10cc 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -3725,10 +3725,25 @@ def func_compute_mass_matrix( i_p = links_info.parent_idx[I_l] if i_p != -1: - links_state.crb_inertial[i_p, i_b] += links_state.crb_inertial[i_l, i_b] - links_state.crb_mass[i_p, i_b] += links_state.crb_mass[i_l, i_b] - links_state.crb_pos[i_p, i_b] += links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] += links_state.crb_quat[i_l, i_b] + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + links_state.crb_inertial[i_p, i_b] += links_state.crb_inertial[i_l, i_b] + links_state.crb_mass[i_p, i_b] += links_state.crb_mass[i_l, i_b] + links_state.crb_pos[i_p, i_b] += links_state.crb_pos[i_l, i_b] + links_state.crb_quat[i_p, i_b] += links_state.crb_quat[i_l, i_b] + else: + links_state.crb_inertial[i_p, i_b] = ( + links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] + ) + links_state.crb_mass[i_p, i_b] = ( + links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] + ) + links_state.crb_pos[i_p, i_b] = ( + links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] + ) + links_state.crb_quat[i_p, i_b] = ( + links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] + ) # mass_mat ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -3867,7 +3882,13 @@ def func_compute_mass_matrix( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(dofs_state.f_ang.shape[0], links_state.pos.shape[1]): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.armature[I_d] + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.armature[I_d] + else: + rigid_global_info.mass_mat[i_d, i_d, i_b] = ( + rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] + ) # Take into account first-order correction terms for implicit integration scheme right away if ti.static(implicit_damping): @@ -5517,12 +5538,23 @@ def func_forward_velocity_entity( if joint_type == gs.JOINT_TYPE.FREE: for i_3 in ti.static(range(3)): - links_state.cd_vel_bw[i_l, curr_i_j, i_b] += ( - dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] - ) - links_state.cd_ang_bw[i_l, curr_i_j, i_b] += ( - dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] - ) + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + links_state.cd_vel_bw[i_l, curr_i_j, i_b] += ( + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) + links_state.cd_ang_bw[i_l, curr_i_j, i_b] += ( + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) + else: + links_state.cd_vel_bw[i_l, curr_i_j, i_b] = ( + links_state.cd_vel_bw[i_l, curr_i_j, i_b] + + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) + links_state.cd_ang_bw[i_l, curr_i_j, i_b] = ( + links_state.cd_ang_bw[i_l, curr_i_j, i_b] + + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + ) for i_3 in ti.static(range(3)): ( @@ -5544,12 +5576,27 @@ def func_forward_velocity_entity( links_state.cd_ang_bw[i_l, next_i_j, i_b] = links_state.cd_ang_bw[i_l, curr_i_j, i_b] for i_3 in ti.static(range(3)): - links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( - dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] - ) - links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( - dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] - ) + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] + * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] + * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + else: + links_state.cd_vel_bw[i_l, next_i_j, i_b] = ( + links_state.cd_vel_bw[i_l, next_i_j, i_b] + + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] + * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] = ( + links_state.cd_ang_bw[i_l, next_i_j, i_b] + + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] + * dofs_state.vel[dof_start + i_3 + 3, i_b] + ) else: for i_d_ in ( @@ -5576,12 +5623,23 @@ def func_forward_velocity_entity( ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) if i_d < joints_info.dof_end[I_j]: - links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( - dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - ) - links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( - dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - ) + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] += ( + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + else: + links_state.cd_vel_bw[i_l, next_i_j, i_b] = ( + links_state.cd_vel_bw[i_l, next_i_j, i_b] + + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) + links_state.cd_ang_bw[i_l, next_i_j, i_b] = ( + links_state.cd_ang_bw[i_l, next_i_j, i_b] + + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + ) i_j_ = 0 if ti.static(not static_rigid_sim_config.is_backward) else n_joints links_state.cd_vel[i_l, i_b] = links_state.cd_vel_bw[i_l, i_j_, i_b] @@ -6340,15 +6398,27 @@ def func_update_acc( # cacc = cacc_parent + cdofdot * qvel + cdof * qacc local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] += local_cdd_vel - links_state.cdd_ang[i_l, i_b] += local_cdd_ang - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] += ( - local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - links_state.cacc_ang[i_l, i_b] += ( - local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + links_state.cdd_vel[i_l, i_b] += local_cdd_vel + links_state.cdd_ang[i_l, i_b] += local_cdd_ang + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] += ( + local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) + links_state.cacc_ang[i_l, i_b] += ( + local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) + else: + links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel + links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang + if ti.static(update_cacc): + links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_l, i_b] + ( + local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) + links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_l, i_b] + ( + local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] + ) @ti.func @@ -6454,8 +6524,16 @@ def func_update_force( I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] if i_p != -1: - links_state.cfrc_vel[i_p, i_b] += links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] += links_state.cfrc_ang[i_l, i_b] + if ti.static(static_rigid_sim_config.is_backward): + links_state.cfrc_vel[i_p, i_b] += links_state.cfrc_vel[i_l, i_b] + links_state.cfrc_ang[i_p, i_b] += links_state.cfrc_ang[i_l, i_b] + else: + links_state.cfrc_vel[i_p, i_b] = ( + links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] + ) + links_state.cfrc_ang[i_p, i_b] = ( + links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] + ) # Clear coupling forces after use ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -6724,7 +6802,11 @@ def func_integrate( dofs_state.vel_next[dof_start + 2, i_b], ] ) - pos += vel * rigid_global_info.substep_dt[None] + # Backward pass requires atomic add + if ti.static(static_rigid_sim_config.is_backward): + pos += vel * rigid_global_info.substep_dt[None] + else: + pos = pos + vel * rigid_global_info.substep_dt[None] for j in ti.static(range(3)): rigid_global_info.qpos_next[q_start + j, i_b] = pos[j] if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: From 38d586524bbc470a8f0079624517d20a81d4c834 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Thu, 27 Nov 2025 23:44:48 -0800 Subject: [PATCH 11/12] remove index checking in forward pass for speedup --- .../solvers/rigid/rigid_solver_decomp.py | 383 +++++++++++++----- 1 file changed, 271 insertions(+), 112 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 047ce10cc..cb56aeccb 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -3669,9 +3669,12 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3705,9 +3708,12 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3719,7 +3725,10 @@ def func_compute_mass_matrix( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - if i < entities_info.n_links[i_e]: + i_valid = ( + True if ti.static(not static_rigid_sim_config.is_backward) else (i < entities_info.n_links[i_e]) + ) + if i_valid: i_l = entities_info.link_end[i_e] - 1 - i I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -3767,9 +3776,12 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3785,8 +3797,10 @@ def func_compute_mass_matrix( i_d = ( i_d_ if ti.static(not static_rigid_sim_config.is_backward) else links_info.dof_start[I_l] + i_d_ ) - - if i_d < links_info.dof_end[I_l]: + i_d_valid = ( + True if ti.static(not static_rigid_sim_config.is_backward) else (i_d < links_info.dof_end[I_l]) + ) + if i_d_valid: dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( links_state.crb_pos[i_l, i_b], links_state.crb_inertial[i_l, i_b], @@ -3816,9 +3830,12 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3854,8 +3871,17 @@ def func_compute_mass_matrix( if ti.static(not static_rigid_sim_config.is_backward) else entities_info.dof_start[i_e] + j_d_ ) - - if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e]: + i_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d < entities_info.dof_end[i_e]) + ) + j_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (j_d < entities_info.dof_end[i_e]) + ) + if i_d_valid and j_d_valid: rigid_global_info.mass_mat[i_d, j_d, i_b] = ( dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) @@ -4186,12 +4212,15 @@ def func_solve_mass_batched( ) ): n_entities = entities_info.n_links.shape[0] + i_0_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_0_valid = i_0 < ( + rigid_global_info.n_awake_entities[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else n_entities + ) - if i_0 < ( - rigid_global_info.n_awake_entities[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else n_entities - ): + if i_0_valid: i_e = ( rigid_global_info.awake_entities[i_0, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -4209,7 +4238,8 @@ def func_solve_mass_batched( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): - if i_d_ < n_dofs: + i_d_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ < n_dofs) + if i_d_valid: i_d = entity_dof_end - i_d_ - 1 if ti.static(static_rigid_sim_config.is_backward): out_bw[0, i_d, i_b] = vec[i_d, i_b] @@ -4226,7 +4256,12 @@ def func_solve_mass_batched( if ti.static(not static_rigid_sim_config.is_backward) else (j_d_ + entities_info.dof_start[i_e]) ) - if j_d >= i_d + 1 and j_d < entity_dof_end: + j_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (j_d >= i_d + 1 and j_d < entity_dof_end) + ) + if j_d_valid: # Since we read out[j_d, i_b], and j_d > i_d, which means that out[j_d, i_b] is already # finalized at this point, we don't need to care about AD mutation rule. if ti.static(static_rigid_sim_config.is_backward): @@ -4234,7 +4269,7 @@ def func_solve_mass_batched( rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out_bw[0, j_d, i_b] ) else: - out[i_d, i_b] += -(rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b]) + out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] # Step 2: z = D^{-1} w for i_d_ in ( @@ -4247,7 +4282,8 @@ def func_solve_mass_batched( if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + entities_info.dof_start[i_e]) ) - if i_d < entity_dof_end: + i_d_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_d < entity_dof_end) + if i_d_valid: if ti.static(static_rigid_sim_config.is_backward): out_bw[1, i_d, i_b] = out_bw[0, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] else: @@ -4264,7 +4300,8 @@ def func_solve_mass_batched( if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + entities_info.dof_start[i_e]) ) - if i_d < entity_dof_end: + i_d_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_d < entity_dof_end) + if i_d_valid: curr_out = out[i_d, i_b] if ti.static(static_rigid_sim_config.is_backward): curr_out = out_bw[1, i_d, i_b] @@ -4279,10 +4316,16 @@ def func_solve_mass_batched( if ti.static(not static_rigid_sim_config.is_backward) else (j_d_ + entities_info.dof_start[i_e]) ) - if j_d < i_d: - curr_out += -(rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]) + j_d_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (j_d < i_d) + if j_d_valid: + if ti.static(static_rigid_sim_config.is_backward): + curr_out += -(rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]) + else: + # Write directly to out[i_d, i_b] for speed up + out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] - out[i_d, i_b] = curr_out + if ti.static(static_rigid_sim_config.is_backward): + out[i_d, i_b] = curr_out @ti.func @@ -4896,11 +4939,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -4924,11 +4970,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -4965,11 +5014,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -4995,11 +5047,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5024,11 +5079,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5068,11 +5126,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5112,7 +5173,12 @@ def func_COM_links( curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 - if i_j < links_info.joint_end[I_l]: + i_j_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_j < links_info.joint_end[I_l]) + ) + if i_j_valid: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j ( @@ -5145,11 +5211,14 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - if i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ) + if i_l_valid: i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5167,7 +5236,12 @@ def func_COM_links( else (i_j_ + links_info.joint_start[I_l]) ) - if i_j < links_info.joint_end[I_l]: + i_j_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_j < links_info.joint_end[I_l]) + ) + if i_j_valid: offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] @@ -5203,7 +5277,12 @@ def func_COM_links( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) - if i_d < joints_info.dof_end[I_j]: + i_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d < joints_info.dof_end[I_j]) + ) + if i_d_valid: dofs_state.cdofvel_ang[i_d, i_b] = ( dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] ) @@ -5361,8 +5440,8 @@ def func_forward_kinematics_entity( ): EPS = rigid_global_info.EPS[None] i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) - - if i_l < entities_info.link_end[i_e]: + i_l_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_l < entities_info.link_end[i_e]) + if i_l_valid: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l links_state.pos_bw[i_l, 0, i_b] = links_info.pos[I_l] @@ -5385,7 +5464,10 @@ def func_forward_kinematics_entity( curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 - if i_j < links_info.joint_end[I_l]: + i_j_valid = ( + True if ti.static(not static_rigid_sim_config.is_backward) else (i_j < links_info.joint_end[I_l]) + ) + if i_j_valid: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] q_start = joints_info.q_start[I_j] @@ -5508,8 +5590,8 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) - - if i_l < entities_info.link_end[i_e]: + i_l_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_l < entities_info.link_end[i_e]) + if i_l_valid: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] @@ -5526,8 +5608,10 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) ): i_j = i_j_ + links_info.joint_start[I_l] - - if i_j < links_info.joint_end[I_l]: + i_j_valid = ( + True if ti.static(not static_rigid_sim_config.is_backward) else (i_j < links_info.joint_end[I_l]) + ) + if i_j_valid: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] q_start = joints_info.q_start[I_j] @@ -5605,7 +5689,12 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) - if i_d < joints_info.dof_end[I_j]: + i_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d < joints_info.dof_end[I_j]) + ) + if i_d_valid: dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( links_state.cd_ang_bw[i_l, curr_i_j, i_b], links_state.cd_vel_bw[i_l, curr_i_j, i_b], @@ -5622,7 +5711,12 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) - if i_d < joints_info.dof_end[I_j]: + i_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d < joints_info.dof_end[I_j]) + ) + if i_d_valid: # Backward pass requires atomic add if ti.static(static_rigid_sim_config.is_backward): links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( @@ -5720,7 +5814,10 @@ def func_update_geoms( ) ): i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 - if i_1 < (n_geoms if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_l_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_l_valid = i_1 < (n_geoms if ti.static(static_rigid_sim_config.use_hibernation) else 1) + if i_l_valid: if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: ( geoms_state.pos[i_g, i_b], @@ -6125,8 +6222,10 @@ def func_torque_and_passive_force( else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) - - if i_l < entities_info.link_end[i_e]: + i_l_valid = ( + True if ti.static(not static_rigid_sim_config.is_backward) else (i_l < entities_info.link_end[i_e]) + ) + if i_l_valid: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l if links_info.n_dofs[I_l] > 0: i_j = links_info.joint_start[I_l] @@ -6143,8 +6242,12 @@ def func_torque_and_passive_force( if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + links_info.dof_start[I_l]) ) - - if i_d < links_info.dof_end[I_l]: + i_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d < links_info.dof_end[I_l]) + ) + if i_d_valid: I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d force = gs.ti_float(0.0) if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: @@ -6245,7 +6348,12 @@ def func_torque_and_passive_force( else ti.static(range(1)) ) ): - if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_d = ( rigid_global_info.awake_dofs[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6276,9 +6384,12 @@ def func_torque_and_passive_force( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6301,7 +6412,8 @@ def func_torque_and_passive_force( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): - if j_d < dof_end: + j_d_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (j_d < dof_end) + if j_d_valid: I_d = ( [dof_start + j_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) @@ -6344,9 +6456,12 @@ def func_update_acc( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6364,7 +6479,12 @@ def func_update_acc( else (i_l_ + entities_info.link_start[i_e]) ) - if i_l < entities_info.link_end[i_e]: + i_l_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_l < entities_info.link_end[i_e]) + ) + if i_l_valid: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -6394,7 +6514,12 @@ def func_update_acc( else (i_d_ + links_info.dof_start[I_l]) ) - if i_d < links_info.dof_end[I_l]: + i_d_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d < links_info.dof_end[I_l]) + ) + if i_d_valid: # cacc = cacc_parent + cdofdot * qvel + cdof * qacc local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] @@ -6450,9 +6575,12 @@ def func_update_force( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6505,9 +6633,12 @@ def func_update_force( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6519,7 +6650,12 @@ def func_update_force( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - if i_l_ < entities_info.n_links[i_e]: + i_l_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_l_ < entities_info.n_links[i_e]) + ) + if i_l_valid: i_l = entities_info.link_end[i_e] - 1 - i_l_ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -6595,9 +6731,12 @@ def func_bias_force( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6615,7 +6754,10 @@ def func_bias_force( if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + links_info.dof_start[I_l]) ) - if i_d < links_info.dof_end[I_l]: + i_d_valid = ( + True if ti.static(not static_rigid_sim_config.is_backward) else (i_d < links_info.dof_end[I_l]) + ) + if i_d_valid: dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( links_state.cfrc_ang[i_l, i_b] ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) @@ -6683,9 +6825,12 @@ def func_compute_qacc( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6698,7 +6843,12 @@ def func_compute_qacc( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): i_d1 = entities_info.dof_start[i_e] + i_d1_ - if i_d1 < entities_info.dof_end[i_e]: + i_d1_valid = ( + True + if ti.static(not static_rigid_sim_config.is_backward) + else (i_d1 < entities_info.dof_end[i_e]) + ) + if i_d1_valid: dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] @@ -6731,7 +6881,12 @@ def func_integrate( else ti.static(range(1)) ) ): - if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_d = ( rigid_global_info.awake_dofs[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6768,9 +6923,12 @@ def func_integrate( else ti.static(range(1)) ) ): - if i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ): + i_1_valid = True + if ti.static(static_rigid_sim_config.is_backward): + i_1_valid = i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ) + if i_1_valid: i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6840,7 +6998,8 @@ def func_integrate( else (ti.static(range(static_rigid_sim_config.max_n_qs_per_link))) ): j = q_start + j_ - if j < q_end: + j_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (j < q_end) + if j_valid: rigid_global_info.qpos_next[j, i_b] = ( rigid_global_info.qpos[j, i_b] + dofs_state.vel_next[dof_start + j_, i_b] * rigid_global_info.substep_dt[None] From a5a3337a5b9d8cb98536dc87ef5d470707c1e717 Mon Sep 17 00:00:00 2001 From: Sanghyun Date: Fri, 28 Nov 2025 00:54:27 -0800 Subject: [PATCH 12/12] reverted index checking except func_solve_mass_batched --- .../solvers/rigid/rigid_solver_decomp.py | 341 +++++------------- 1 file changed, 99 insertions(+), 242 deletions(-) diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index cb56aeccb..6cc6b988b 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -3669,12 +3669,9 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3708,12 +3705,9 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3725,10 +3719,7 @@ def func_compute_mass_matrix( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - i_valid = ( - True if ti.static(not static_rigid_sim_config.is_backward) else (i < entities_info.n_links[i_e]) - ) - if i_valid: + if i < entities_info.n_links[i_e]: i_l = entities_info.link_end[i_e] - 1 - i I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -3776,12 +3767,9 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3797,10 +3785,8 @@ def func_compute_mass_matrix( i_d = ( i_d_ if ti.static(not static_rigid_sim_config.is_backward) else links_info.dof_start[I_l] + i_d_ ) - i_d_valid = ( - True if ti.static(not static_rigid_sim_config.is_backward) else (i_d < links_info.dof_end[I_l]) - ) - if i_d_valid: + + if i_d < links_info.dof_end[I_l]: dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( links_state.crb_pos[i_l, i_b], links_state.crb_inertial[i_l, i_b], @@ -3830,12 +3816,9 @@ def func_compute_mass_matrix( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -3871,17 +3854,8 @@ def func_compute_mass_matrix( if ti.static(not static_rigid_sim_config.is_backward) else entities_info.dof_start[i_e] + j_d_ ) - i_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d < entities_info.dof_end[i_e]) - ) - j_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (j_d < entities_info.dof_end[i_e]) - ) - if i_d_valid and j_d_valid: + + if i_d < entities_info.dof_end[i_e] and j_d < entities_info.dof_end[i_e]: rigid_global_info.mass_mat[i_d, j_d, i_b] = ( dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) @@ -4321,7 +4295,6 @@ def func_solve_mass_batched( if ti.static(static_rigid_sim_config.is_backward): curr_out += -(rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]) else: - # Write directly to out[i_d, i_b] for speed up out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] if ti.static(static_rigid_sim_config.is_backward): @@ -4939,14 +4912,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -4970,14 +4940,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5014,14 +4981,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5047,14 +5011,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5079,14 +5040,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5126,14 +5084,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5173,12 +5128,7 @@ def func_COM_links( curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 - i_j_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_j < links_info.joint_end[I_l]) - ) - if i_j_valid: + if i_j < links_info.joint_end[I_l]: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j ( @@ -5211,14 +5161,11 @@ def func_COM_links( else ti.static(range(links_info.root_idx.shape[0])) ) ): - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_l_ < ( - rigid_global_info.n_awake_links[i_b] - if ti.static(static_rigid_sim_config.use_hibernation) - else links_info.root_idx.shape[0] - ) - if i_l_valid: + if i_l_ < ( + rigid_global_info.n_awake_links[i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else links_info.root_idx.shape[0] + ): i_l = ( rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ ) @@ -5236,12 +5183,7 @@ def func_COM_links( else (i_j_ + links_info.joint_start[I_l]) ) - i_j_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_j < links_info.joint_end[I_l]) - ) - if i_j_valid: + if i_j < links_info.joint_end[I_l]: offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] @@ -5277,12 +5219,7 @@ def func_COM_links( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) - i_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d < joints_info.dof_end[I_j]) - ) - if i_d_valid: + if i_d < joints_info.dof_end[I_j]: dofs_state.cdofvel_ang[i_d, i_b] = ( dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] ) @@ -5440,8 +5377,8 @@ def func_forward_kinematics_entity( ): EPS = rigid_global_info.EPS[None] i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) - i_l_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_l < entities_info.link_end[i_e]) - if i_l_valid: + + if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l links_state.pos_bw[i_l, 0, i_b] = links_info.pos[I_l] @@ -5464,10 +5401,7 @@ def func_forward_kinematics_entity( curr_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ next_i_j = 0 if ti.static(not static_rigid_sim_config.is_backward) else i_j_ + 1 - i_j_valid = ( - True if ti.static(not static_rigid_sim_config.is_backward) else (i_j < links_info.joint_end[I_l]) - ) - if i_j_valid: + if i_j < links_info.joint_end[I_l]: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] q_start = joints_info.q_start[I_j] @@ -5590,8 +5524,8 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) - i_l_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (i_l < entities_info.link_end[i_e]) - if i_l_valid: + + if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] @@ -5608,10 +5542,8 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_joints_per_link)) ): i_j = i_j_ + links_info.joint_start[I_l] - i_j_valid = ( - True if ti.static(not static_rigid_sim_config.is_backward) else (i_j < links_info.joint_end[I_l]) - ) - if i_j_valid: + + if i_j < links_info.joint_end[I_l]: I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j joint_type = joints_info.type[I_j] q_start = joints_info.q_start[I_j] @@ -5689,12 +5621,7 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) - i_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d < joints_info.dof_end[I_j]) - ) - if i_d_valid: + if i_d < joints_info.dof_end[I_j]: dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( links_state.cd_ang_bw[i_l, curr_i_j, i_b], links_state.cd_vel_bw[i_l, curr_i_j, i_b], @@ -5711,12 +5638,7 @@ def func_forward_velocity_entity( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_joint)) ): i_d = i_d_ if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + dof_start) - i_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d < joints_info.dof_end[I_j]) - ) - if i_d_valid: + if i_d < joints_info.dof_end[I_j]: # Backward pass requires atomic add if ti.static(static_rigid_sim_config.is_backward): links_state.cd_vel_bw[i_l, next_i_j, i_b] += ( @@ -5814,10 +5736,7 @@ def func_update_geoms( ) ): i_g = i_1 + entities_info.geom_start[i_e] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 - i_l_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_l_valid = i_1 < (n_geoms if ti.static(static_rigid_sim_config.use_hibernation) else 1) - if i_l_valid: + if i_1 < (n_geoms if ti.static(static_rigid_sim_config.use_hibernation) else 1): if force_update_fixed_geoms or not geoms_info.is_fixed[i_g]: ( geoms_state.pos[i_g, i_b], @@ -6222,10 +6141,8 @@ def func_torque_and_passive_force( else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): i_l = i_l_ if ti.static(not static_rigid_sim_config.is_backward) else (i_l_ + entities_info.link_start[i_e]) - i_l_valid = ( - True if ti.static(not static_rigid_sim_config.is_backward) else (i_l < entities_info.link_end[i_e]) - ) - if i_l_valid: + + if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l if links_info.n_dofs[I_l] > 0: i_j = links_info.joint_start[I_l] @@ -6242,12 +6159,8 @@ def func_torque_and_passive_force( if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + links_info.dof_start[I_l]) ) - i_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d < links_info.dof_end[I_l]) - ) - if i_d_valid: + + if i_d < links_info.dof_end[I_l]: I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d force = gs.ti_float(0.0) if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: @@ -6348,12 +6261,7 @@ def func_torque_and_passive_force( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): i_d = ( rigid_global_info.awake_dofs[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6384,12 +6292,9 @@ def func_torque_and_passive_force( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6412,8 +6317,7 @@ def func_torque_and_passive_force( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_dofs_per_link)) ): - j_d_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (j_d < dof_end) - if j_d_valid: + if j_d < dof_end: I_d = ( [dof_start + j_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) @@ -6456,12 +6360,9 @@ def func_update_acc( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6479,12 +6380,7 @@ def func_update_acc( else (i_l_ + entities_info.link_start[i_e]) ) - i_l_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_l < entities_info.link_end[i_e]) - ) - if i_l_valid: + if i_l < entities_info.link_end[i_e]: I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -6514,12 +6410,7 @@ def func_update_acc( else (i_d_ + links_info.dof_start[I_l]) ) - i_d_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d < links_info.dof_end[I_l]) - ) - if i_d_valid: + if i_d < links_info.dof_end[I_l]: # cacc = cacc_parent + cdofdot * qvel + cdof * qacc local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] @@ -6575,12 +6466,9 @@ def func_update_force( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6633,12 +6521,9 @@ def func_update_force( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6650,12 +6535,7 @@ def func_update_force( if ti.static(not static_rigid_sim_config.is_backward) else ti.static(range(static_rigid_sim_config.max_n_links_per_entity)) ): - i_l_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_l_ < entities_info.n_links[i_e]) - ) - if i_l_valid: + if i_l_ < entities_info.n_links[i_e]: i_l = entities_info.link_end[i_e] - 1 - i_l_ I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] @@ -6731,12 +6611,9 @@ def func_bias_force( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6754,10 +6631,7 @@ def func_bias_force( if ti.static(not static_rigid_sim_config.is_backward) else (i_d_ + links_info.dof_start[I_l]) ) - i_d_valid = ( - True if ti.static(not static_rigid_sim_config.is_backward) else (i_d < links_info.dof_end[I_l]) - ) - if i_d_valid: + if i_d < links_info.dof_end[I_l]: dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( links_state.cfrc_ang[i_l, i_b] ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) @@ -6825,12 +6699,9 @@ def func_compute_qacc( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_entities[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_e = ( rigid_global_info.awake_entities[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6843,12 +6714,7 @@ def func_compute_qacc( else ti.static(range(static_rigid_sim_config.max_n_dofs_per_entity)) ): i_d1 = entities_info.dof_start[i_e] + i_d1_ - i_d1_valid = ( - True - if ti.static(not static_rigid_sim_config.is_backward) - else (i_d1 < entities_info.dof_end[i_e]) - ) - if i_d1_valid: + if i_d1 < entities_info.dof_end[i_e]: dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] @@ -6881,12 +6747,7 @@ def func_integrate( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < (rigid_global_info.n_awake_dofs[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1): i_d = ( rigid_global_info.awake_dofs[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6923,12 +6784,9 @@ def func_integrate( else ti.static(range(1)) ) ): - i_1_valid = True - if ti.static(static_rigid_sim_config.is_backward): - i_1_valid = i_1 < ( - rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 - ) - if i_1_valid: + if i_1 < ( + rigid_global_info.n_awake_links[i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 1 + ): i_l = ( rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) @@ -6998,8 +6856,7 @@ def func_integrate( else (ti.static(range(static_rigid_sim_config.max_n_qs_per_link))) ): j = q_start + j_ - j_valid = True if ti.static(not static_rigid_sim_config.is_backward) else (j < q_end) - if j_valid: + if j < q_end: rigid_global_info.qpos_next[j, i_b] = ( rigid_global_info.qpos[j, i_b] + dofs_state.vel_next[dof_start + j_, i_b] * rigid_global_info.substep_dt[None]