[FEATURE] Differentiable forward dynamics for rigid body sim. #1808
base: main
Conversation
Force-pushed 9f033bb to 2a8c057.
Benchmark report. Thresholds: runtime ≤ −10%, compile ≥ +10%. Runtime FPS and Compile Time results not shown.
Is the noslip feature supported when enabling gradient computation? If not, you should raise an exception at init.
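For reference, a minimal sketch of such an init-time guard, assuming the option is exposed as something like self._options.noslip_iterations (the actual option name may differ):

if self._static_rigid_sim_config.requires_grad and self._options.noslip_iterations > 0:
    gs.raise_exception("noslip is not supported yet when requires_grad is True.")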
constraint_state.Mgrad,
constraint_state.Mgrad,  # this will not be used anyway because is_backward is False
After checking, it seems that another variable is specified when is_backward=True, so here, if I understand correctly, you are setting it to anything because it will not be used in practice. This does work in practice but I don't like it much... What about defining some extra free variable PLACEHOLDER in array_class (a 0D taichi tensor of type array_class.V) that we could use everywhere an argument is not used? This would clarify the intent and avoid any mistake, because you cannot do much with such a tensor.
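Something along these lines, as a rough sketch (the dtype and exact module placement are assumptions; in practice it would be the 0-D array_class.V suggested above):

# In array_class: a 0-D Taichi field that can be passed wherever a kernel argument
# is declared but never read, making the "unused" intent explicit.
PLACEHOLDER = ti.field(dtype=gs.ti_float, shape=())

# At the call site, instead of passing constraint_state.Mgrad for an argument that is ignored:
# some_func(..., array_class.PLACEHOLDER, ...)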
Benchmark report. Thresholds: runtime ≤ −10%, compile ≥ +10%. Runtime FPS and Compile Time results not shown.
Hey! I was just wondering if this PR is close to being merged soon, since the Heterogeneous Simulation PR depends on this. Thanks! PR: #1589
Hey, sorry for the delay. This PR needs some polishing to pass some benchmark tests, and I'm working on it. I'll try to wrap it up as soon as possible.
Thanks so much! Sorry for being pushy, but do you have a rough timeline for this? I'm looking to train with the heterogeneous simulation and might reprioritize some things.
It's hard to say, but I think at least a week is needed (because code review is needed again for merging). I'll let you know if I have a better estimate.
Following up on this: any updates on the ETA? Also looking to use heterogeneous simulations (#1589), which depends on this.
Hi, this PR will be merged tomorrow if code review goes well. Based on the review, it could be delayed a little, but I'll try to finish up by this weekend.
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False, update_tgt=True):
    """
    Set quaternion of the entity's base link.

    Parameters
    ----------
    quat : array_like
        The quaternion to set.
    relative : bool, optional
        Whether the quaternion to set is absolute or relative to the initial (not current!) quaternion. Defaults to
        False.
    zero_velocity : bool, optional
        Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a
        sudden change in entity pose.
    envs_idx : None | array_like, optional
        The indices of the environments. If None, all environments will be considered. Defaults to None.
    """
    # Save in [tgt] for backward pass
    if update_tgt:
update_tgt seems to be an internal feature. I would recommend calling self.update_tgt manually right before set_quat to avoid leaking this option to the user.
After checking twice: keep the self.update_tgt call inside this function, but rather use a global state variable self._update_tgt to track whether the target must be updated or not, without exposing this option?
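A minimal sketch of that second option (the class and helper names here are illustrative, not the actual Genesis API):

class _EntityLike:
    def __init__(self):
        # Internal flag toggled by the gradient machinery instead of a public keyword argument.
        self._update_tgt = True

    def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
        if self._update_tgt:
            self._save_tgt_for_backward(quat)  # hypothetical helper standing in for the [tgt] bookkeeping
        # ... existing set_quat logic would follow here ...

    def _save_tgt_for_backward(self, quat):
        self._tgt_quat = quat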
@gs.assert_built
def set_quat_grad(self, envs_idx, relative, unsafe, quat_grad):
    tmp_quat_grad = quat_grad.unsqueeze(-2)
I don't get it. Why don't you pass quat_grad.data directly?
i_e = self.contact_island.entity_id[i_e_, i_b]
self._solver.mass_mat_mask[i_e_, i_b] = True
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, i_b)
self._solver._func_solve_mass_batched(self.grad, self.Mgrad, None, i_b)
I don't think passing None is supported.
@@ -250,10 +264,60 @@ def build(self):
    enable_joint_limit=getattr(self, "_enable_joint_limit", False),
    box_box_detection=getattr(self, "_box_box_detection", True),
    sparse_solve=getattr(self._options, "sparse_solve", False),
    integrator=getattr(self, "_integrator", gs.integrator.implicitfast),
    integrator=getattr(self, "_integrator", gs.integrator.approximate_implicitfast),
Weird to include this "bug fix" in this PR but OK.
if isinstance(self.sim.coupler, SAPCoupler):
    gs.raise_exception("SAPCoupler is not supported yet when requires_grad is True.")

if isinstance(self.sim.coupler, IPCCoupler):
    gs.raise_exception("IPCCoupler is not supported yet when requires_grad is True.")
if isinstance(self.sim.coupler, (SAPCoupler, IPCCoupler)):
    gs.raise_exception(f"{type(self.sim.coupler).__name__} is not supported yet when requires_grad is True.")
# Add terms for static inner loops, use 0 if not requires_grad to avoid re-compilation
self._static_rigid_sim_config.max_n_links_per_entity = (
    getattr(self, "_max_n_links_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_joints_per_link = (
    getattr(self, "_max_n_joints_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_dofs_per_joint = (
    getattr(self, "_max_n_dofs_per_joint", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_qs_per_link = (
    getattr(self, "_max_n_qs_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_dofs_per_entity = (
    getattr(self, "_max_n_dofs_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_dofs_per_link = (
    getattr(self, "_max_n_dofs_per_link", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_geoms_per_entity = (
    getattr(self, "_max_n_geoms_per_entity", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_awake_links = (
    getattr(self, "_n_links", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_awake_entities = (
    getattr(self, "_n_entities", 0) if self._static_rigid_sim_config.requires_grad else 0
)
self._static_rigid_sim_config.max_n_awake_dofs = (
    getattr(self, "_n_dofs", 0) if self._static_rigid_sim_config.requires_grad else 0
)
Stop using getattr; it is bad practice.
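For example, assuming these attributes are always set by the time build() runs, the same block can read them directly (only the first few shown as a sketch):

cfg = self._static_rigid_sim_config
cfg.max_n_links_per_entity = self._max_n_links_per_entity if cfg.requires_grad else 0
cfg.max_n_joints_per_link = self._max_n_joints_per_link if cfg.requires_grad else 0
cfg.max_n_dofs_per_joint = self._max_n_dofs_per_joint if cfg.requires_grad else 0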
@@ -291,7 +361,7 @@ def build(self):
    self._init_constraint_solver()

    self._init_invweight_and_meaninertia(force_update=False)
    self._func_update_geoms(self._scene._envs_idx, force_update_fixed_geoms=True)
    self._func_update_geoms(self._scene._envs_idx, force_update_fixed_geoms=True, is_backward=False)
We should just make is_backward=False the default value. It would be easier.
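i.e., something like this for the signature, so existing call sites would not need to change (a sketch; the real parameter list is longer):

def _func_update_geoms(self, envs_idx, force_update_fixed_geoms=False, is_backward=False):
    ...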
curr_qpos = self._rigid_adjoint_cache.qpos.to_numpy()[f]
curr_dofs_vel = self._rigid_adjoint_cache.dofs_vel.to_numpy()[f]
curr_dofs_acc = self._rigid_adjoint_cache.dofs_acc.to_numpy()[f]
self._rigid_global_info.qpos.from_numpy(curr_qpos)
self.dofs_state.vel.from_numpy(curr_dofs_vel)
self.dofs_state.acc.from_numpy(curr_dofs_acc)
This is too inefficient. We need to find a better way.
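One possible way to avoid the host round-trip would be a device-side restore that mirrors kernel_save_adjoint_cache further down; a sketch (the kernel name is made up, decorator arguments and loop_config omitted):

@ti.kernel
def kernel_load_adjoint_cache(
    f: ti.i32,
    dofs_state: array_class.DofsState,
    rigid_global_info: array_class.RigidGlobalInfo,
    rigid_adjoint_cache: array_class.RigidAdjointCache,
):
    # Copy frame f of the adjoint cache back into the current state without going through numpy.
    for i_d, i_b in ti.ndrange(dofs_state.vel.shape[0], dofs_state.vel.shape[1]):
        dofs_state.vel[i_d, i_b] = rigid_adjoint_cache.dofs_vel[f, i_d, i_b]
        dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b]
    for i_q, i_b in ti.ndrange(rigid_global_info.qpos.shape[0], rigid_global_info.qpos.shape[1]):
        rigid_global_info.qpos[i_q, i_b] = rigid_adjoint_cache.qpos[f, i_q, i_b]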
qpos_grad = self._rigid_global_info.qpos.grad.to_numpy()
dofs_vel_grad = self.dofs_state.vel.grad.to_numpy()
if np.isnan(qpos_grad).sum() > 0 or np.isnan(dofs_vel_grad).sum() > 0:
    gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")
Same, very inefficient. We cannot afford it, I think.
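If the check has to stay in the hot path, a device-side reduction would avoid the full copies; a rough sketch, relying on the fact that a value is NaN iff it compares unequal to itself:

@ti.kernel
def kernel_count_nan(grad_field: ti.template()) -> ti.i32:
    # Counts NaN entries of a 2D gradient field directly on device.
    n_nan = 0
    for i, j in grad_field:
        if grad_field[i, j] != grad_field[i, j]:
            n_nan += 1
    return n_nan

# Usage sketch:
# if kernel_count_nan(self._rigid_global_info.qpos.grad) > 0 or kernel_count_nan(self.dofs_state.vel.grad) > 0:
#     gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")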
kernel_copy_next_to_curr.grad(
    dofs_state=self.dofs_state,
    rigid_global_info=self._rigid_global_info,
    static_rigid_sim_config=self._static_rigid_sim_config,
)
All these small kernels are dominated by launch overhead. You should instead merge all of this into one gigantic kernel.
qpos_grad = gs.zeros_like(state.qpos)
dofs_vel_grad = gs.zeros_like(state.dofs_vel)
links_pos_grad = gs.zeros_like(state.links_pos)
links_quat_grad = gs.zeros_like(state.links_quat)

if state.qpos.grad is not None:
    qpos_grad = state.qpos.grad

if state.dofs_vel.grad is not None:
    dofs_vel_grad = state.dofs_vel.grad

if state.links_pos.grad is not None:
    links_pos_grad = state.links_pos.grad

if state.links_quat.grad is not None:
    links_quat_grad = state.links_quat.grad
Avoid skipping all these lines, you are diluting information.
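For instance, each default-plus-override pair above could be collapsed into a single conditional expression, which also skips the zero allocation when a gradient already exists:

qpos_grad = state.qpos.grad if state.qpos.grad is not None else gs.zeros_like(state.qpos)
dofs_vel_grad = state.dofs_vel.grad if state.dofs_vel.grad is not None else gs.zeros_like(state.dofs_vel)
links_pos_grad = state.links_pos.grad if state.links_pos.grad is not None else gs.zeros_like(state.links_pos)
links_quat_grad = state.links_quat.grad if state.links_quat.grad is not None else gs.zeros_like(state.links_quat)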
self._ckpt[ckpt_name]["qpos"] = self._rigid_adjoint_cache.qpos.to_numpy()
self._ckpt[ckpt_name]["dofs_vel"] = self._rigid_adjoint_cache.dofs_vel.to_numpy()
self._ckpt[ckpt_name]["dofs_acc"] = self._rigid_adjoint_cache.dofs_acc.to_numpy()
If this is part of the hot path, avoid using to_numpy and prefer ti_to_numpy
@@ -1468,6 +1806,15 @@ def _sanitize_2D_io_variables(
        gs.raise_exception("Expecting 2D input tensor.")
    return tensor, _inputs_idx, envs_idx

def _sanitize_2D_io_variables_grad(
    self,
    grad_after_sanitization,
    grad_before_sanitization,
):
    if grad_after_sanitization.shape != grad_before_sanitization.shape:
        gs.raise_exception("Shape of grad_after_sanitization and grad_before_sanitization do not match.")
    return grad_after_sanitization
This does not "sanitize" anything. It "validates".
Besides, these are static functions. So I would just move them outside this class (right before it, as free functions).
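i.e., roughly (the names here are only illustrative):

def _validate_io_variables_grad(grad_after, grad_before):
    # Free module-level helper: the two buffers must have identical shapes.
    if grad_after.shape != grad_before.shape:
        gs.raise_exception("Shape of grad_after and grad_before do not match.")
    return grad_after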
velocity_grad_, dofs_idx, envs_idx = self._sanitize_1D_io_variables(
    velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe
)
if self.n_envs == 0:
    velocity_grad_ = velocity_grad_.unsqueeze(0)
kernel_set_dofs_velocity_grad(
    velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
)
if self.n_envs == 0:
    velocity_grad_ = velocity_grad_.squeeze(0)
velocity_grad.data = self._sanitize_1D_io_variables_grad(velocity_grad_, velocity_grad)
Why don't you just pass velocity_grad.data directly?

velocity_grad.data, dofs_idx, envs_idx = self._sanitize_1D_io_variables(
    velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe
)
velocity_grad_ = velocity_grad.data
if self.n_envs == 0:
    velocity_grad_ = velocity_grad_.unsqueeze(0)
kernel_set_dofs_velocity_grad(
    velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
)
Why is this validation really needed?
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: ti.template(),
):

No need to skip line here.
@@ -6206,6 +7418,7 @@ def kernel_set_state(
    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b_ in ti.ndrange(n_dofs, envs_idx.shape[0]):
        dofs_state.vel[i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d]
        dofs_state.acc[i_d, envs_idx[i_b_]] = dofs_acc[envs_idx[i_b_], i_d]
        dofs_state.ctrl_force[i_d, envs_idx[i_b_]] = gs.ti_float(0.0)
        dofs_state.ctrl_mode[i_d, envs_idx[i_b_]] = gs.CTRL_MODE.FORCE
I think we should move adding acc to rigid state in another PR.
The new unit test in this PR fails if we don't include it here. Can you reconsider it?
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_copy_cartesian_space(
    src_dofs_state: array_class.DofsState,
    src_links_state: array_class.LinksState,
    src_joints_state: array_class.JointsState,
    src_geoms_state: array_class.GeomsState,
    dst_dofs_state: array_class.DofsState,
    dst_links_state: array_class.LinksState,
    dst_joints_state: array_class.JointsState,
    dst_geoms_state: array_class.GeomsState,
):
    # Copy outputs of [kernel_update_cartesian_space] among [dofs, links, joints, geoms] states. This is used to
    # restore the outputs that were overwritten if we disabled mujoco compatibility for backward pass.

    # dofs state
    for i_d, i_b in ti.ndrange(src_dofs_state.pos.shape[0], src_dofs_state.pos.shape[1]):
        # pos, cdof_ang, cdof_vel, cdofvel_ang, cdofvel_vel, cdofd_ang, cdofd_vel
        dst_dofs_state.pos[i_d, i_b] = src_dofs_state.pos[i_d, i_b]
        dst_dofs_state.cdof_ang[i_d, i_b] = src_dofs_state.cdof_ang[i_d, i_b]
        dst_dofs_state.cdof_vel[i_d, i_b] = src_dofs_state.cdof_vel[i_d, i_b]
        dst_dofs_state.cdofvel_ang[i_d, i_b] = src_dofs_state.cdofvel_ang[i_d, i_b]
        dst_dofs_state.cdofvel_vel[i_d, i_b] = src_dofs_state.cdofvel_vel[i_d, i_b]
        dst_dofs_state.cdofd_ang[i_d, i_b] = src_dofs_state.cdofd_ang[i_d, i_b]
        dst_dofs_state.cdofd_vel[i_d, i_b] = src_dofs_state.cdofd_vel[i_d, i_b]

    # links state
    for i_l, i_b in ti.ndrange(src_links_state.pos.shape[0], src_links_state.pos.shape[1]):
        # pos, quat, root_COM, mass_sum, i_pos, i_quat, cinr_inertial, cinr_pos, cinr_quat, cinr_mass, j_pos, j_quat,
        # cd_vel, cd_ang
        dst_links_state.pos[i_l, i_b] = src_links_state.pos[i_l, i_b]
        dst_links_state.quat[i_l, i_b] = src_links_state.quat[i_l, i_b]
        dst_links_state.root_COM[i_l, i_b] = src_links_state.root_COM[i_l, i_b]
        dst_links_state.mass_sum[i_l, i_b] = src_links_state.mass_sum[i_l, i_b]
        dst_links_state.i_pos[i_l, i_b] = src_links_state.i_pos[i_l, i_b]
        dst_links_state.i_quat[i_l, i_b] = src_links_state.i_quat[i_l, i_b]
        dst_links_state.cinr_inertial[i_l, i_b] = src_links_state.cinr_inertial[i_l, i_b]
        dst_links_state.cinr_pos[i_l, i_b] = src_links_state.cinr_pos[i_l, i_b]
        dst_links_state.cinr_quat[i_l, i_b] = src_links_state.cinr_quat[i_l, i_b]
        dst_links_state.cinr_mass[i_l, i_b] = src_links_state.cinr_mass[i_l, i_b]
        dst_links_state.j_pos[i_l, i_b] = src_links_state.j_pos[i_l, i_b]
        dst_links_state.j_quat[i_l, i_b] = src_links_state.j_quat[i_l, i_b]
        dst_links_state.cd_vel[i_l, i_b] = src_links_state.cd_vel[i_l, i_b]
        dst_links_state.cd_ang[i_l, i_b] = src_links_state.cd_ang[i_l, i_b]

    # joints state
    for i_j, i_b in ti.ndrange(src_joints_state.xanchor.shape[0], src_joints_state.xanchor.shape[1]):
        # xanchor, xaxis
        dst_joints_state.xanchor[i_j, i_b] = src_joints_state.xanchor[i_j, i_b]
        dst_joints_state.xaxis[i_j, i_b] = src_joints_state.xaxis[i_j, i_b]

    # geoms state
    for i_g, i_b in ti.ndrange(src_geoms_state.pos.shape[0], src_geoms_state.pos.shape[1]):
        # pos, quat, verts_updated
        dst_geoms_state.pos[i_g, i_b] = src_geoms_state.pos[i_g, i_b]
        dst_geoms_state.quat[i_g, i_b] = src_geoms_state.quat[i_g, i_b]
        dst_geoms_state.verts_updated[i_g, i_b] = src_geoms_state.verts_updated[i_g, i_b]
Ideally we should avoid this kind of tiny kernel that does almost nothing, because the overhead will dominate by a very large margin. I will try to see if we can leverage zero-copy for this. Ideally, it should be moved into a larger kernel.
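A rough sketch of that direction (the combined kernel and its name are illustrative): turn each small body into a ti.func and call them from one launch, so the per-launch overhead is paid once.

@ti.func
def func_copy_cartesian_space(src_dofs_state, dst_dofs_state):
    for i_d, i_b in ti.ndrange(src_dofs_state.pos.shape[0], src_dofs_state.pos.shape[1]):
        dst_dofs_state.pos[i_d, i_b] = src_dofs_state.pos[i_d, i_b]
        # ... remaining per-field copies as in kernel_copy_cartesian_space above ...

@ti.kernel
def kernel_backward_prologue(
    src_dofs_state: array_class.DofsState,
    dst_dofs_state: array_class.DofsState,
    dofs_state: array_class.DofsState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: ti.template(),
):
    # One launch chaining the small copies instead of several one-off kernels.
    func_copy_cartesian_space(src_dofs_state, dst_dofs_state)
    func_copy_next_to_curr(
        dofs_state=dofs_state,
        rigid_global_info=rigid_global_info,
        static_rigid_sim_config=static_rigid_sim_config,
    )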
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_save_adjoint_cache(
    f: ti.int32,
    dofs_state: array_class.DofsState,
    rigid_global_info: array_class.RigidGlobalInfo,
    rigid_adjoint_cache: array_class.RigidAdjointCache,
    static_rigid_sim_config: ti.template(),
):
    n_dofs = dofs_state.vel.shape[0]
    n_qs = rigid_global_info.qpos.shape[0]
    _B = dofs_state.vel.shape[1]

    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b in ti.ndrange(n_dofs, _B):
        rigid_adjoint_cache.dofs_vel[f, i_d, i_b] = dofs_state.vel[i_d, i_b]
        rigid_adjoint_cache.dofs_acc[f, i_d, i_b] = dofs_state.acc[i_d, i_b]

    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_q, i_b in ti.ndrange(n_qs, _B):
        rigid_adjoint_cache.qpos[f, i_q, i_b] = rigid_global_info.qpos[i_q, i_b]
Same. This kernel is too small. Ideally it should be a ti.func that is part of a bigger kernel.
@ti.kernel(fastcache=gs.use_fastcache)
def kernel_copy_next_to_curr(
    dofs_state: array_class.DofsState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: ti.template(),
):
    func_copy_next_to_curr(
        dofs_state=dofs_state,
        rigid_global_info=rigid_global_info,
        static_rigid_sim_config=static_rigid_sim_config,
    )
Same. This kernel is too small. Ideally it should be a ti.func that is part of a bigger kernel.
for i_l, i_b in ti.ndrange(links_info.root_idx.shape[0], links_state.pos.shape[1]):
    links_state.cfrc_coupling_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3)
    links_state.cfrc_coupling_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3)
for I in ti.grouped(ti.ndrange(*links_state.cfrc_coupling_ang.shape)):
    links_state.cfrc_coupling_ang[I] = ti.Vector.zero(gs.ti_float, 3)
    links_state.cfrc_coupling_vel[I] = ti.Vector.zero(gs.ti_float, 3)

I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l

i_r = links_info.root_idx[I_l]
if i_l == i_r:
    links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b]
if i_l == i_r and links_state.mass_sum[i_l, i_b] > 0.0:
This condition is already enforced by design. Why are you adding it? Did you observe any issue in practice?
def max_n_qs_per_link(self):
    if self.is_built:
        return self._max_n_qs_per_link
    return max([link.n_qs for link in self.links]) if len(self.links) > 0 else 0
max(link.n_qs for link in self.links) if self.links else 0
def max_n_dofs_per_entity(self):
    if self.is_built:
        return self._max_n_dofs_per_entity
    return max([entity.n_dofs for entity in self._entities]) if len(self._entities) > 0 else 0
max(entity.n_dofs for entity in self._entities) if self._entities else 0
    return max([link.n_dofs for link in self.links]) if len(self.links) > 0 else 0

@property
def max_n_dofs_per_joint(self):
    if self.is_built:
        return self._max_n_dofs_per_joint
    return max([joint.n_dofs for joint in self.joints]) if len(self.joints) > 0 else 0
Same.
@property
def max_n_joints_per_link(self):
    if self.is_built:
        return self._max_n_joints_per_link
    return max([len(link.joints) for link in self.links]) if len(self.links) > 0 else 0

@property
def n_geoms(self):
    if self.is_built:
        return self._n_geoms
    return len(self.geoms)

@property
def max_n_geoms_per_entity(self):
    if self.is_built:
        return self._max_n_geoms_per_entity
    return max([entity.n_geoms for entity in self._entities]) if len(self._entities) > 0 else 0
max(len(link.joints) for link in self.links) if self.links else 0
@property
def max_n_links_per_entity(self):
    if self.is_built:
        return self._max_n_links_per_entity
    return max([len(entity.links) for entity in self._entities]) if len(self._entities) > 0 else 0
max(len(entity.links) for entity in self._entities) if self._entities else 0
def set_dofs_velocity_grad(self, dofs_idx, envs_idx, unsafe, velocity_grad):
    velocity_grad_, dofs_idx, envs_idx = self._sanitize_1D_io_variables(
        velocity_grad, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe
    )
    if self.n_envs == 0:
        velocity_grad_ = velocity_grad_.unsqueeze(0)
    kernel_set_dofs_velocity_grad(
        velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
    )
    if self.n_envs == 0:
        velocity_grad_ = velocity_grad_.squeeze(0)
    velocity_grad.data = self._sanitize_1D_io_variables_grad(velocity_grad_, velocity_grad)
It seems more complex than necessary. Why can't you assign velocity_grad.data directly, and why do you need to validate the shape?
Force-pushed 7fe798f to 05398a0.
Description
This PR has the following changes:
- is_backward (static when is_backward=True)
- We do not add an additional dimension to the states in the rigid body simulation for the frames, because it would incur too much code change.
Related Issue
Resolves Genesis-Embodied-AI/Genesis#
Motivation and Context
This is part of the process of making the rigid body simulation differentiable.
How Has This Been / Can This Be Tested?
There is a unit test that verifies this by solving an optimization problem: tests/test_grad.py::test_differentiable_rigid.
Screenshots (if appropriate):
Checklist:
Submitting Code Changes section of the CONTRIBUTING document.