Conversation

@SonSang
Collaborator

@SonSang SonSang commented Oct 7, 2025

Description

This PR includes the following changes:

  • Unify hibernation / non-hibernation code in rigid_solver_decomp.py
  • Unify dynamic and static inner loops based on is_backward (static when is_backward=True); see the sketch below
  • Differentiable formulations for functions in rigid_solver_decomp.py

We do not add an additional frame dimension to the states in the rigid body simulation, because it would require too many code changes.
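For readers unfamiliar with the second item above, here is a minimal, self-contained sketch of what "static when is_backward=True" means in Taichi. The kernel and field names are made up for illustration; the actual code lives in rigid_solver_decomp.py.

import taichi as ti

ti.init(arch=ti.cpu)

MAX_N_DOFS = 8  # hypothetical compile-time upper bound (e.g. max_n_dofs_per_entity)
vel = ti.field(dtype=ti.f32, shape=MAX_N_DOFS)
acc = ti.field(dtype=ti.f32, shape=MAX_N_DOFS)

@ti.kernel
def integrate(n_dofs: ti.i32, dt: ti.f32, is_backward: ti.template()):
    if ti.static(is_backward):
        # Static inner loop: the trip count is fixed at compile time, which
        # Taichi's autodiff handles far more reliably than a data-dependent bound.
        for i in ti.static(range(MAX_N_DOFS)):
            if i < n_dofs:
                vel[i] += acc[i] * dt
    else:
        # Dynamic inner loop: the bound is only known at runtime; cheaper to
        # compile and fine for the non-differentiable forward pass.
        for i in range(n_dofs):
            vel[i] += acc[i] * dt

integrate(6, 0.01, False)  # forward-only path
integrate(6, 0.01, True)   # differentiable path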

Related Issue

Resolves Genesis-Embodied-AI/Genesis#

Motivation and Context

This is part of the ongoing effort to make the rigid body simulation differentiable.

How Has This Been / Can This Be Tested?

There is a unit test, tests/test_grad.py::test_differentiable_rigid, which verifies the gradients by solving an optimization problem.
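For reviewers unfamiliar with this kind of check, the general pattern looks roughly like the toy example below (a PyTorch stand-in for a single differentiable step; this is not the Genesis API and not the actual test): gradients are propagated through the whole rollout and used to solve a small optimization problem, and convergence of that problem is what gets asserted.

import torch

def sim_step(pos, vel, dt=0.01):
    # Toy stand-in for one differentiable simulation step: a point mass under gravity.
    acc = torch.tensor([0.0, 0.0, -9.81])
    vel = vel + acc * dt
    pos = pos + vel * dt
    return pos, vel

target = torch.tensor([0.5, 0.0, 1.0])
init_vel = torch.zeros(3, requires_grad=True)
opt = torch.optim.Adam([init_vel], lr=0.1)

for _ in range(200):
    pos, vel = torch.tensor([0.0, 0.0, 1.0]), init_vel
    for _ in range(50):          # roll the simulation forward
        pos, vel = sim_step(pos, vel)
    loss = torch.sum((pos - target) ** 2)
    opt.zero_grad()
    loss.backward()              # gradients flow back through every step
    opt.step()

print(f"final loss: {loss.item():.2e}")  # should be near zero if gradients are correct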

Screenshots (if appropriate):

Checklist:

  • I read the CONTRIBUTING document.
  • I followed the Submitting Code Changes section of the CONTRIBUTING document.
  • I tagged the title correctly (including BUG FIX/FEATURE/MISC/BREAKING).
  • I updated the documentation accordingly or no change is needed.
  • I tested my changes and added instructions on how to test it for reviewers.
  • I have added tests to cover my changes.
  • All new and existing tests passed.

@SonSang SonSang changed the title from "[FEATURE] WIP: Differentiable forward dynamics for rigid body simulation (minimal fix)" to "[FEATURE] Differentiable forward dynamics for rigid body simulation (minimal fix)" on Oct 18, 2025
@github-actions

⚠️ Benchmark Regression Detected
Baselines considered: 5 commits

Thresholds: runtime ≤ −10%, compile ≥ +10%

Runtime FPS

| status | benchmark_id | current FPS | baseline FPS | Δ FPS |
| --- | --- | --- | --- | --- |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 45,712 | 30,105 | +51.84% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=True-use_contact_island=False | 19,600 | 20,663 | -5.14% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=True-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 16,159 | 16,819 | -3.92% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=False-use_contact_island=False | 2,356,744 | 11,879,586 | -80.16% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=True-use_contact_island=False | 2,373,289 | 11,974,315 | -80.18% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=False-use_contact_island=False | 3,650,183 | 11,971,721 | -69.51% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=True-use_contact_island=False | 3,560,228 | 12,006,226 | -70.35% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=False-use_contact_island=False | 2,353,910 | 11,966,930 | -80.33% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=True-use_contact_island=False | 2,331,798 | 11,929,539 | -80.45% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=False-use_contact_island=False | 3,576,319 | 12,001,647 | -70.20% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=True-use_contact_island=False | 3,486,764 | 11,965,350 | -70.86% |
| 🔴 | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=False-use_contact_island=False | 519,700 | 1,597,992 | -67.48% |
| 🔴 | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=True-use_contact_island=False | 519,931 | 1,596,952 | -67.44% |
| 🔴 | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=False-use_contact_island=False | 282,152 | 449,266 | -37.20% |
| 🔴 | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=True-use_contact_island=False | 276,196 | 448,622 | -38.43% |

Compile Time

| status | benchmark_id | current compile | baseline compile | Δ compile |
| --- | --- | --- | --- | --- |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 32 | 33 | -3.03% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=True-use_contact_island=False | 37 | 36 | +2.78% |
| 🔴 | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=True-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 31 | 28 | +10.71% |
|  | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=False-use_contact_island=False | 41 | 38 | +7.89% |
|  | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=True-use_contact_island=False | 41 | 40 | +2.50% |
|  | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=False-use_contact_island=False | 41 | 39 | +5.13% |
|  | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=True-use_contact_island=False | 39 | 41 | -4.88% |
|  | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=False-use_contact_island=False | 41 | 39 | +5.13% |
|  | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=True-use_contact_island=False | 42 | 40 | +5.00% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=False-use_contact_island=False | 42 | 38 | +10.53% |
|  | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=True-use_contact_island=False | 42 | 41 | +2.44% |
|  | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=False-use_contact_island=False | 42 | 41 | +2.44% |
|  | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=True-use_contact_island=False | 41 | 39 | +5.13% |
|  | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=False-use_contact_island=False | 43 | 41 | +4.88% |
|  | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=True-use_contact_island=False | 40 | 41 | -2.44% |

Collaborator

Is the noslip feature supported when gradient computation is enabled? If not, you should raise an exception at init.
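For reference, the kind of init-time guard being suggested could look like the snippet below. The option name noslip_iterations and the config attributes are assumptions for illustration, not the actual Genesis fields; only gs.raise_exception is taken from the code in this PR.

# Hypothetical guard in the solver's build/init path; attribute names are illustrative.
if self._static_rigid_sim_config.requires_grad and self._options.noslip_iterations > 0:
    gs.raise_exception("noslip is not supported yet when requires_grad is True.")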

Comment on lines 38 to 39
constraint_state.Mgrad,
constraint_state.Mgrad, # this will not be used anyway because is_backward is False
Collaborator

After checking, it seems that another variable is specified when is_backward=True, so if I understand correctly you are passing anything here because it will not be used in practice. This does work, but I don't like it much... What about defining some extra free variable PLACEHOLDER in array_class (a 0-D Taichi tensor of type array_class.V) that we could use everywhere an argument is not used? This would clarify the intent and avoid any mistake, because you cannot do much with such a tensor.
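A minimal sketch of the PLACEHOLDER idea with a plain 0-D Taichi field (the actual array_class.V type in Genesis may wrap this differently; names here are illustrative):

import taichi as ti

ti.init(arch=ti.cpu)

# Module-level placeholder: passed wherever a kernel argument is syntactically
# required but never read (e.g. the Mgrad slot when is_backward=False).
PLACEHOLDER = ti.field(dtype=ti.f32, shape=())

@ti.kernel
def solve(grad: ti.template(), out: ti.template(), unused: ti.template()):
    # `unused` would receive PLACEHOLDER on the forward path; being 0-D, it
    # cannot silently be used as if it were real per-dof data.
    for i in range(grad.shape[0]):
        out[i] = grad[i]

grad = ti.field(dtype=ti.f32, shape=4)
out = ti.field(dtype=ti.f32, shape=4)
solve(grad, out, PLACEHOLDER)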

@duburcqa duburcqa changed the title from "[FEATURE] Differentiable forward dynamics for rigid body simulation (minimal fix)" to "[FEATURE] Differentiable forward dynamics for rigid body sim." on Oct 23, 2025
@github-actions

⚠️ Benchmark Regression Detected
Baselines considered: 5 commits

Thresholds: runtime ≤ −10%, compile ≥ +10%

Runtime FPS

| status | benchmark_id | current FPS | baseline FPS | Δ FPS |
| --- | --- | --- | --- | --- |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 43,946 | 30,086 | +46.07% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=True-use_contact_island=False | 19,992 | 20,513 | -2.54% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=True-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 16,605 | 16,756 | -0.90% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=False-use_contact_island=False | 2,059,965 | 13,411,484 | -84.64% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=True-use_contact_island=False | 2,331,107 | 13,402,335 | -82.61% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=False-use_contact_island=False | 3,304,798 | 18,408,303 | -82.05% |
| 🔴 | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=True-use_contact_island=False | 4,030,974 | 18,336,977 | -78.02% |
| ℹ️ | batch_size=30000-constraint_solver=CG-env=random-gjk_collision=False-use_contact_island=False | 946,802 | nan | +nan% |
| ℹ️ | batch_size=30000-constraint_solver=CG-env=random-gjk_collision=True-use_contact_island=False | 843,972 | nan | +nan% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=False-use_contact_island=False | 1,999,001 | 13,036,483 | -84.67% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=True-use_contact_island=False | 2,081,023 | 13,004,375 | -84.00% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=False-use_contact_island=False | 3,109,363 | 18,417,110 | -83.12% |
| 🔴 | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=True-use_contact_island=False | 3,256,759 | 18,607,265 | -82.50% |
| ℹ️ | batch_size=30000-constraint_solver=Newton-env=random-gjk_collision=False-use_contact_island=False | 1,084,131 | nan | +nan% |
| ℹ️ | batch_size=30000-constraint_solver=Newton-env=random-gjk_collision=True-use_contact_island=False | 1,030,587 | nan | +nan% |
| 🔴 | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=False-use_contact_island=False | 539,180 | 1,614,643 | -66.61% |
| 🔴 | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=True-use_contact_island=False | 537,193 | 1,617,820 | -66.80% |
| 🔴 | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=False-use_contact_island=False | 287,082 | 450,107 | -36.22% |
| 🔴 | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=True-use_contact_island=False | 289,511 | 449,772 | -35.63% |

Compile Time

| status | benchmark_id | current compile | baseline compile | Δ compile |
| --- | --- | --- | --- | --- |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 35 | 37 | -5.41% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=False-env=box_pyramid#5-gjk_collision=True-use_contact_island=False | 35 | 37 | -5.41% |
|  | batch_size=2048-constraint_solver=Newton-enable_mujoco_compatibility=True-env=box_pyramid#5-gjk_collision=False-use_contact_island=False | 31 | 31 | +0.00% |
|  | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=False-use_contact_island=False | 42 | 41 | +2.44% |
|  | batch_size=30000-constraint_solver=CG-env=anymal_c-gjk_collision=True-use_contact_island=False | 41 | 40 | +2.50% |
|  | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=False-use_contact_island=False | 42 | 42 | +0.00% |
|  | batch_size=30000-constraint_solver=CG-env=batched_franka-gjk_collision=True-use_contact_island=False | 42 | 42 | +0.00% |
| ℹ️ | batch_size=30000-constraint_solver=CG-env=random-gjk_collision=False-use_contact_island=False | 42 | nan | +nan% |
| ℹ️ | batch_size=30000-constraint_solver=CG-env=random-gjk_collision=True-use_contact_island=False | 43 | nan | +nan% |
|  | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=False-use_contact_island=False | 42 | 39 | +7.69% |
|  | batch_size=30000-constraint_solver=Newton-env=anymal_c-gjk_collision=True-use_contact_island=False | 42 | 40 | +5.00% |
|  | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=False-use_contact_island=False | 42 | 40 | +5.00% |
|  | batch_size=30000-constraint_solver=Newton-env=batched_franka-gjk_collision=True-use_contact_island=False | 43 | 42 | +2.38% |
| ℹ️ | batch_size=30000-constraint_solver=Newton-env=random-gjk_collision=False-use_contact_island=False | 42 | nan | +nan% |
| ℹ️ | batch_size=30000-constraint_solver=Newton-env=random-gjk_collision=True-use_contact_island=False | 43 | nan | +nan% |
|  | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=False-use_contact_island=False | 41 | 39 | +5.13% |
|  | batch_size=8192-constraint_solver=CG-env=cube#10-gjk_collision=True-use_contact_island=False | 40 | 39 | +2.56% |
|  | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=False-use_contact_island=False | 39 | 39 | +0.00% |
|  | batch_size=8192-constraint_solver=Newton-env=cube#10-gjk_collision=True-use_contact_island=False | 41 | 40 | +2.50% |

@Satvik1701

Satvik1701 commented Nov 8, 2025

Hey! I was just wondering if this PR is close to being merged soon, since the Heterogeneous Simulation PR is dependent on this. Thanks! PR: #1589

@SonSang
Collaborator Author

SonSang commented Nov 8, 2025

> Hey! I was just wondering if this PR is close to being merged soon, since the Heterogeneous Simulation PR is dependent on this. Thanks! PR: #1589

Hey, sorry for the delay. This PR needs some polishing to pass the benchmark tests, and I'm working on it. I'll try to wrap it up as soon as possible.

@Satvik1701

> Hey! I was just wondering if this PR is close to being merged soon, since the Heterogeneous Simulation PR is dependent on this. Thanks! PR: #1589

> Hey, sorry for the delay. This PR needs some polishing to pass the benchmark tests, and I'm working on it. I'll try to wrap it up as soon as possible.

Thanks so much! Yeah, sorry for being pushy, but do you have a rough timeline for this? I'm looking to train with the heterogeneous simulation and might reprioritize some things.

@SonSang
Collaborator Author

SonSang commented Nov 8, 2025

> Hey! I was just wondering if this PR is close to being merged soon, since the Heterogeneous Simulation PR is dependent on this. Thanks! PR: #1589

> Hey, sorry for the delay. This PR needs some polishing to pass the benchmark tests, and I'm working on it. I'll try to wrap it up as soon as possible.

> Thanks so much! Yeah, sorry for being pushy, but do you have a rough timeline for this? I'm looking to train with the heterogeneous simulation and might reprioritize some things.

It's hard to say, but I think at least a week is needed (the code has to be reviewed again before merging). I'll let you know once I have a better estimate.

@SamratSahoo

Following up on this: any updates on the ETA? I'm also looking to use heterogeneous simulations (#1589), which depends on this.

@SonSang
Collaborator Author

SonSang commented Nov 18, 2025

> Following up on this: any updates on the ETA? I'm also looking to use heterogeneous simulations (#1589), which depends on this.

Hi, this PR will be merged tomorrow if code review goes well. Depending on the review it could be delayed a little, but I'll try to finish up by this weekend.

@github-actions

⚠️ Benchmark Regression Detected
➡️ Report

@github-actions

⚠️ Benchmark Regression Detected
➡️ Report

Comment on lines 2117 to 2135
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False, update_tgt=True):
"""
Set quaternion of the entity's base link.
Parameters
----------
quat : array_like
The quaternion to set.
relative : bool, optional
Whether the quaternion to set is absolute or relative to the initial (not current!) quaternion. Defaults to
False.
zero_velocity : bool, optional
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a
sudden change in entity pose.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
# Save in [tgt] for backward pass
if update_tgt:
Collaborator

update_tgt seems to be an internal feature. I would recommend calling self.update_tgt manually right before set_quat, to avoid leaking this option to the user.

Collaborator

After checking twice: keep the self.update_tgt call inside this function, but use an internal state variable self._update_tgt to track whether the target must be updated, without exposing this option?
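Something along these lines, as a rough sketch (class and helper names are made up; only the idea of an internal _update_tgt flag comes from the comment above):

class RigidEntitySketch:
    # Hypothetical sketch, not the actual Genesis RigidEntity.

    def __init__(self):
        self._update_tgt = True    # internal flag instead of a public keyword argument
        self._tgt_quat = None
        self._quat = None

    def set_quat(self, quat):
        if self._update_tgt:       # the [tgt] bookkeeping stays inside set_quat,
            self._tgt_quat = quat  # but the option is no longer user-facing
        self._quat = quat

    def _set_quat_without_tgt(self, quat):
        # Internal call sites (e.g. state restoration during the backward pass)
        # flip the flag around the call instead of exposing update_tgt to users.
        self._update_tgt = False
        try:
            self.set_quat(quat)
        finally:
            self._update_tgt = True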


@gs.assert_built
def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad):
tmp_quat_grad = quat_grad.unsqueeze(-2)
Collaborator

I don't get it. Why don't you pass quat_grad.data directly?

i_e = self.contact_island.entity_id[i_e_, i_b]
self._solver.mass_mat_mask[i_e_, i_b] = True
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b)
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, None, i_b)
Collaborator

I don't think passing None is supported.

@@ -250,10 +264,60 @@ def build(self):
enable_joint_limit=getattr(self, "_enable_joint_limit", False),
box_box_detection=getattr(self, "_box_box_detection", True),
sparse_solve=getattr(self._options, "sparse_solve", False),
integrator=getattr(self, "_integrator", gs.integrator.implicitfast),
integrator=getattr(self, "_integrator", gs.integrator.approximate_implicitfast),
Collaborator

Weird to include this "bug fix" in this PR but OK.

Comment on lines 280 to 284
if isinstance(self.sim.coupler, SAPCoupler):
gs.raise_exception("SAPCoupler is not supported yet when requires_grad is True.")

if isinstance(self.sim.coupler, IPCCoupler):
gs.raise_exception("IPCCoupler is not supported yet when requires_grad is True.")
Collaborator

if isinstance(self.sim.coupler, (SAPCoupler, IPCCoupler)):
    gs.raise_exception(f"{type(self.sim.coupler).__name__} is not supported yet when requires_grad is True.")
    

Comment on lines 289 to 337
# Add terms for static inner loops, use 0 if not requires_grad to avoid re-compilation
self._static_rigid_sim_config.max_n_links_per_entity = (
getattr(self, "_max_n_links_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_joints_per_link = (
getattr(self, "_max_n_joints_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_dofs_per_joint = (
getattr(self, "_max_n_dofs_per_joint", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_qs_per_link = (
getattr(self, "_max_n_qs_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_dofs_per_entity = (
getattr(self, "_max_n_dofs_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_dofs_per_link = (
getattr(self, "_max_n_dofs_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_geoms_per_entity = (
getattr(self, "_max_n_geoms_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_awake_links = (
getattr(self, "_n_links", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_awake_entities = (
getattr(self, "_n_entities", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_awake_dofs = (
getattr(self, "_n_dofs", 0) if self._static_rigid_sim_config.requires_grad else 0
)
Collaborator

Stop using getattr, it is a bad practice.
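For instance, the block above could read the attributes directly (a sketch; whether every _max_n_* attribute is guaranteed to exist at this point of build() is an assumption):

requires_grad = self._static_rigid_sim_config.requires_grad
self._static_rigid_sim_config.max_n_links_per_entity = (
    self._max_n_links_per_entity if requires_grad else 0
)
self._static_rigid_sim_config.max_n_joints_per_link = (
    self._max_n_joints_per_link if requires_grad else 0
)
# ... likewise for the remaining max_n_* fields; a missing attribute now raises
# AttributeError instead of being silently replaced by 0.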

@@ -291,7 +361,7 @@ def build(self):
self._init_constraint_solver()

self._init_invweight_and_meaninertia(force_update=False)
self._func_update_geoms(self._scene._envs_idx, force_update_fixed_geoms=True)
self._func_update_geoms(self._scene._envs_idx, force_update_fixed_geoms=True, is_backward=False)
Collaborator

We should just make is_backward=False the default value. It would be easier.

Comment on lines 1227 to 1232
curr_qpos = self._rigid_adjoint_cache.qpos.to_numpy()[f]
curr_dofs_vel = self._rigid_adjoint_cache.dofs_vel.to_numpy()[f]
curr_dofs_acc = self._rigid_adjoint_cache.dofs_acc.to_numpy()[f]
self._rigid_global_info.qpos.from_numpy(curr_qpos)
self.dofs_state.vel.from_numpy(curr_dofs_vel)
self.dofs_state.acc.from_numpy(curr_dofs_acc)
Collaborator

This is too inefficient. We need to find a better way.

Comment on lines 1283 to 1286
qpos_grad = self._rigid_global_info.qpos.grad.to_numpy()
dofs_vel_grad = self.dofs_state.vel.grad.to_numpy()
if np.isnan(qpos_grad).sum() > 0 or np.isnan(dofs_vel_grad).sum() > 0:
gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")
Collaborator

Same, very inefficient. I don't think we can afford it.

Comment on lines 1288 to 1292
kernel_copy_next_to_curr.grad(
dofs_state=self.dofs_state,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
)
Collaborator

All these small kernels are dominated by the overhead. You should rather make all of this a gigantic kernel.
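A toy illustration of what merging these launches could look like (field names and sizes are made up; the real backward pass would fuse the actual copy/restore/grad steps):

import taichi as ti

ti.init(arch=ti.cpu)

N, B = 16, 4  # toy sizes: dofs, batch
vel_curr = ti.field(dtype=ti.f32, shape=(N, B))
vel_next = ti.field(dtype=ti.f32, shape=(N, B))
acc = ti.field(dtype=ti.f32, shape=(N, B))

@ti.kernel
def backward_substep(dt: ti.f32):
    # One launch covering what used to be several tiny kernels.
    for i, b in ti.ndrange(N, B):      # was kernel_copy_next_to_curr
        vel_curr[i, b] = vel_next[i, b]
    for i, b in ti.ndrange(N, B):      # toy stand-in for the next small step
        vel_next[i, b] = vel_curr[i, b] + acc[i, b] * dt

backward_substep(0.01)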

Comment on lines 1429 to 1434
qpos_grad = gs.zeros_like(state.qpos)
dofs_vel_grad = gs.zeros_like(state.dofs_vel)
links_pos_grad = gs.zeros_like(state.links_pos)
links_quat_grad = gs.zeros_like(state.links_quat)

if state.qpos.grad is not None:
qpos_grad = state.qpos.grad

if state.dofs_vel.grad is not None:
dofs_vel_grad = state.dofs_vel.grad

if state.links_pos.grad is not None:
links_pos_grad = state.links_pos.grad

if state.links_quat.grad is not None:
links_quat_grad = state.links_quat.grad
Collaborator

Avoid skipping all these lines, you are diluting information.

@github-actions

⚠️ Benchmark Regression Detected
➡️ Report

Comment on lines 1590 to 1592
self._ckpt[ckpt_name]["qpos"] = self._rigid_adjoint_cache.qpos.to_numpy()
self._ckpt[ckpt_name]["dofs_vel"] = self._rigid_adjoint_cache.dofs_vel.to_numpy()
self._ckpt[ckpt_name]["dofs_acc"] = self._rigid_adjoint_cache.dofs_acc.to_numpy()
Collaborator

If this is part of the hot path, avoid using to_numpy and prefer ti_to_numpy.

Comment on lines 1713 to 1814
@@ -1468,6 +1806,15 @@ def _sanitize_2D_io_variables(
gs.raise_exception("Expecting 2D input tensor.")
return tensor, _inputs_idx, envs_idx

def _sanitize_2D_io_variables_grad(
self,
grad_after_sanitization,
grad_before_sanitization,
):
if grad_after_sanitization.shape != grad_before_sanitization.shape:
gs.raise_exception("Shape of grad_after_sanitization and grad_before_sanitization do not match.")
return grad_after_sanitization

Collaborator

This does not "sanitize" anything. It "validates".

Collaborator

Besides, these are static functions, so I would just move them outside this class (right before it, as free functions).
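In other words, something like the following (keeping the behaviour of the method above, with the clearer "validate" name; gs is the module-level import already present in that file):

def _validate_2D_io_variables_grad(grad_after, grad_before):
    # Free function: it touches no instance state, so it can live right before
    # the class. It only checks shapes and returns its input unchanged.
    if grad_after.shape != grad_before.shape:
        gs.raise_exception("Shape of grad_after and grad_before do not match.")
    return grad_after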

Comment on lines 2264 to 2274
velocity_grad_, dofs_idx, envs_idx = self._sanitize_1D_io_variables(
velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe
)
if self.n_envs == 0:
velocity_grad_ = velocity_grad_.unsqueeze(0)
kernel_set_dofs_velocity_grad(
velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
)
if self.n_envs == 0:
velocity_grad_ = velocity_grad_.squeeze(0)
velocity_grad.data = self._sanitize_1D_io_variables_grad(velocity_grad_, velocity_grad)
Collaborator

Why don't you just pass velocity_grad.data directly to the kernel?

        velocity_grad.data, dofs_idx, envs_idx = self._sanitize_1D_io_variables(
            velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe
        )
        velocity_grad_ = velocity_grad.data
        if self.n_envs == 0:
             velocity_grad_ = velocity_grad_.unsqueeze(0)
        kernel_set_dofs_velocity_grad(
            velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
        )

Collaborator

Why is this validation really needed?

rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
):

Collaborator

No need to skip a line here.

Comment on lines 7345 to 7275
@@ -6206,6 +7418,7 @@ def kernel_set_state(
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b_ in ti.ndrange(n_dofs, envs_idx.shape[0]):
dofs_state.vel[i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d]
dofs_state.acc[i_d, envs_idx[i_b_]] = dofs_acc[envs_idx[i_b_], i_d]
dofs_state.ctrl_force[i_d, envs_idx[i_b_]] = gs.ti_float(0.0)
dofs_state.ctrl_mode[i_d, envs_idx[i_b_]] = gs.CTRL_MODE.FORCE
Collaborator

I think we should move adding acc to the rigid state to a separate PR.

Collaborator Author

The new unit test in this PR fails if we don't include it here. Can you reconsider it?

Comment on lines 7157 to 7212
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_copy_cartesian_space(
src_dofs_state: array_class.DofsState,
src_links_state: array_class.LinksState,
src_joints_state: array_class.JointsState,
src_geoms_state: array_class.GeomsState,
dst_dofs_state: array_class.DofsState,
dst_links_state: array_class.LinksState,
dst_joints_state: array_class.JointsState,
dst_geoms_state: array_class.GeomsState,
):
# Copy outputs of [kernel_update_cartesian_space] among [dofs, links, joints, geoms] states. This is used to restore
# the outputs that were overwritten if we disabled mujoco compatibility for backward pass.

# dofs state
for i_d, i_b in ti.ndrange(src_dofs_state.pos.shape[0], src_dofs_state.pos.shape[1]):
# pos, cdof_ang, cdof_vel, cdofvel_ang, cdofvel_vel, cdofd_ang, cdofd_vel
dst_dofs_state.pos[i_d, i_b] = src_dofs_state.pos[i_d, i_b]
dst_dofs_state.cdof_ang[i_d, i_b] = src_dofs_state.cdof_ang[i_d, i_b]
dst_dofs_state.cdof_vel[i_d, i_b] = src_dofs_state.cdof_vel[i_d, i_b]
dst_dofs_state.cdofvel_ang[i_d, i_b] = src_dofs_state.cdofvel_ang[i_d, i_b]
dst_dofs_state.cdofvel_vel[i_d, i_b] = src_dofs_state.cdofvel_vel[i_d, i_b]
dst_dofs_state.cdofd_ang[i_d, i_b] = src_dofs_state.cdofd_ang[i_d, i_b]
dst_dofs_state.cdofd_vel[i_d, i_b] = src_dofs_state.cdofd_vel[i_d, i_b]

# links state
for i_l, i_b in ti.ndrange(src_links_state.pos.shape[0], src_links_state.pos.shape[1]):
# pos, quat, root_COM, mass_sum, i_pos, i_quat, cinr_inertial, cinr_pos, cinr_quat, cinr_mass, j_pos, j_quat,
# cd_vel, cd_ang
dst_links_state.pos[i_l, i_b] = src_links_state.pos[i_l, i_b]
dst_links_state.quat[i_l, i_b] = src_links_state.quat[i_l, i_b]
dst_links_state.root_COM[i_l, i_b] = src_links_state.root_COM[i_l, i_b]
dst_links_state.mass_sum[i_l, i_b] = src_links_state.mass_sum[i_l, i_b]
dst_links_state.i_pos[i_l, i_b] = src_links_state.i_pos[i_l, i_b]
dst_links_state.i_quat[i_l, i_b] = src_links_state.i_quat[i_l, i_b]
dst_links_state.cinr_inertial[i_l, i_b] = src_links_state.cinr_inertial[i_l, i_b]
dst_links_state.cinr_pos[i_l, i_b] = src_links_state.cinr_pos[i_l, i_b]
dst_links_state.cinr_quat[i_l, i_b] = src_links_state.cinr_quat[i_l, i_b]
dst_links_state.cinr_mass[i_l, i_b] = src_links_state.cinr_mass[i_l, i_b]
dst_links_state.j_pos[i_l, i_b] = src_links_state.j_pos[i_l, i_b]
dst_links_state.j_quat[i_l, i_b] = src_links_state.j_quat[i_l, i_b]
dst_links_state.cd_vel[i_l, i_b] = src_links_state.cd_vel[i_l, i_b]
dst_links_state.cd_ang[i_l, i_b] = src_links_state.cd_ang[i_l, i_b]

# joints state
for i_j, i_b in ti.ndrange(src_joints_state.xanchor.shape[0], src_joints_state.xanchor.shape[1]):
# xanchor, xaxis
dst_joints_state.xanchor[i_j, i_b] = src_joints_state.xanchor[i_j, i_b]
dst_joints_state.xaxis[i_j, i_b] = src_joints_state.xaxis[i_j, i_b]

# geoms state
for i_g, i_b in ti.ndrange(src_geoms_state.pos.shape[0], src_geoms_state.pos.shape[1]):
# pos, quat, verts_updated
dst_geoms_state.pos[i_g, i_b] = src_geoms_state.pos[i_g, i_b]
dst_geoms_state.quat[i_g, i_b] = src_geoms_state.quat[i_g, i_b]
dst_geoms_state.verts_updated[i_g, i_b] = src_geoms_state.verts_updated[i_g, i_b]
Collaborator

Ideally we should avoid this kind of tiny kernel that does almost nothing, because the overhead will dominate by a very large margin. I will try to see if we can leverage zero-copy for this. Ideally, it should be moved into a larger kernel.

Comment on lines 7135 to 7154
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_save_adjoint_cache(
f: ti.int32,
dofs_state: array_class.DofsState,
rigid_global_info: array_class.RigidGlobalInfo,
rigid_adjoint_cache: array_class.RigidAdjointCache,
static_rigid_sim_config: ti.template(),
):
n_dofs = dofs_state.vel.shape[0]
n_qs = rigid_global_info.qpos.shape[0]
_B = dofs_state.vel.shape[1]

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in ti.ndrange(n_dofs, _B):
rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b]
rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b]

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_q, i_b in ti.ndrange(n_qs, _B):
rigid_adjoint_cache.qpos[f, i_q, i_b] = rigid_global_info.qpos[i_q, i_b]
Collaborator

Same. This kernel is too small. Ideally it should be a ti.func that is part of a bigger kernel.
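Roughly what that refactor could look like with toy fields (names and the surrounding substep body are made up for illustration):

import taichi as ti

ti.init(arch=ti.cpu)

N, B, F = 16, 4, 32  # toy sizes: dofs, batch, cached frames
vel = ti.field(dtype=ti.f32, shape=(N, B))
acc = ti.field(dtype=ti.f32, shape=(N, B))
cache_vel = ti.field(dtype=ti.f32, shape=(F, N, B))
cache_acc = ti.field(dtype=ti.f32, shape=(F, N, B))

@ti.func
def save_adjoint_cache_dof(f, i_d, i_b):
    # Per-element body of the old kernel_save_adjoint_cache, now inlined into
    # whatever larger kernel already iterates over (i_d, i_b).
    cache_vel[f, i_d, i_b] = vel[i_d, i_b]
    cache_acc[f, i_d, i_b] = acc[i_d, i_b]

@ti.kernel
def substep_and_cache(f: ti.i32, dt: ti.f32):
    for i_d, i_b in ti.ndrange(N, B):
        vel[i_d, i_b] += acc[i_d, i_b] * dt   # toy stand-in for the real substep work
        save_adjoint_cache_dof(f, i_d, i_b)   # caching rides along without an extra launch

substep_and_cache(0, 0.01)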

Comment on lines 7107 to 7117
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_copy_next_to_curr(
dofs_state: array_class.DofsState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
):
func_copy_next_to_curr(
dofs_state=dofs_state,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
)
Collaborator

Same. This kernel is too small. Ideally it should be a ti.func that is part of a bigger kernel.

Comment on lines 6804 to 6806
for i_l, i_b in ti.ndrange(links_info.root_idx.shape[0], links_state.pos.shape[1]):
links_state.cfrc_coupling_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3)
links_state.cfrc_coupling_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3)
Collaborator

@duburcqa duburcqa Nov 19, 2025

    for I in ti.grouped(ti.ndrange(*links_state.cfrc_coupling_ang.shape)):
        links_state.cfrc_coupling_ang[I] = ti.Vector.zero(gs.ti_float, 3)
        links_state.cfrc_coupling_vel[I] = ti.Vector.zero(gs.ti_float, 3)

I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l

i_r = links_info.root_idx[I_l]
if i_l == i_r:
links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b]
if i_l == i_r and links_state.mass_sum[i_l, i_b] > 0.0:
Collaborator

This condition is already enforced by design. Why are you adding it? Did you observe any issue in practice?

def max_n_qs_per_link(self):
if self.is_built:
return self._max_n_qs_per_link
return max([link.n_qs for link in self.links]) if len(self.links) > 0 else 0
Collaborator

max(link.n_qs for link in self.links) if self.links else 0

def max_n_dofs_per_entity(self):
if self.is_built:
return self._max_n_dofs_per_entity
return max([entity.n_dofs for entity in self._entities]) if len(self._entities) > 0 else 0
Collaborator

max(entity.n_dofs for entity in self._entities) if self._entities else 0

Comment on lines 2953 to 2959
return max([link.n_dofs for link in self.links]) if len(self.links) > 0 else 0

@property
def max_n_dofs_per_joint(self):
if self.is_built:
return self._max_n_dofs_per_joint
return max([joint.n_dofs for joint in self.joints]) if len(self.joints) > 0 else 0
Collaborator

Same.

Comment on lines 2853 to 2869
@property
def max_n_joints_per_link(self):
if self.is_built:
return self._max_n_joints_per_link
return max([len(link.joints) for link in self.links]) if len(self.links) > 0 else 0

@property
def n_geoms(self):
if self.is_built:
return self._n_geoms
return len(self.geoms)

@property
def max_n_geoms_per_entity(self):
if self.is_built:
return self._max_n_geoms_per_entity
return max([entity.n_geoms for entity in self._entities]) if len(self._entities) > 0 else 0
Collaborator

max(len(link.joints) for link in self.links) if self.links else 0

Comment on lines 2841 to 2845
@property
def max_n_links_per_entity(self):
if self.is_built:
return self._max_n_links_per_entity
return max([len(entity.links) for entity in self._entities]) if len(self._entities) > 0 else 0
Collaborator

max(len(entity.links) for entity in self._entities) if self._entities else 0

Comment on lines 2263 to 2274
def set_dofs_velocity_grad(self, dofs_idx, envs_idx, unsafe, velocity_grad):
velocity_grad_, dofs_idx, envs_idx = self._sanitize_1D_io_variables(
velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe
)
if self.n_envs == 0:
velocity_grad_ = velocity_grad_.unsqueeze(0)
kernel_set_dofs_velocity_grad(
velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
)
if self.n_envs == 0:
velocity_grad_ = velocity_grad_.squeeze(0)
velocity_grad.data = self._sanitize_1D_io_variables_grad(velocity_grad_, velocity_grad)
Collaborator

It seems more complex than necessary. Why can't you assign velocity_grad.data directly, and why do you need to validate the shape?

@github-actions

⚠️ Benchmark Regression Detected
➡️ Report
