Skip to content

Commit 7fe798f

Browse files
committed
removed kernel for saving adjoint cache
1 parent c4bca92 commit 7fe798f

File tree

2 files changed

+60
-66
lines changed

2 files changed

+60
-66
lines changed

genesis/engine/solvers/rigid/rigid_solver_decomp.py

Lines changed: 58 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -921,16 +921,8 @@ def substep(self, f):
921921

922922
self._links_state_cache.clear()
923923

924-
if f == 0 and self._requires_grad:
925-
kernel_save_adjoint_cache(
926-
f=f,
927-
dofs_state=self.dofs_state,
928-
rigid_global_info=self._rigid_global_info,
929-
rigid_adjoint_cache=self._rigid_adjoint_cache,
930-
static_rigid_sim_config=self._static_rigid_sim_config,
931-
)
932-
933924
kernel_step_1(
925+
f=f,
934926
links_state=self.links_state,
935927
links_info=self.links_info,
936928
joints_state=self.joints_state,
@@ -942,6 +934,7 @@ def substep(self, f):
942934
entities_state=self.entities_state,
943935
entities_info=self.entities_info,
944936
rigid_global_info=self._rigid_global_info,
937+
rigid_adjoint_cache=self._rigid_adjoint_cache,
945938
static_rigid_sim_config=self._static_rigid_sim_config,
946939
contact_island_state=self.constraint_solver.contact_island.contact_island_state,
947940
is_backward=False,
@@ -957,6 +950,7 @@ def substep(self, f):
957950
else:
958951
self._func_constraint_force()
959952
kernel_step_2(
953+
f=f,
960954
dofs_state=self.dofs_state,
961955
dofs_info=self.dofs_info,
962956
links_info=self.links_info,
@@ -969,20 +963,12 @@ def substep(self, f):
969963
geoms_state=self.geoms_state,
970964
collider_state=self.collider._collider_state,
971965
rigid_global_info=self._rigid_global_info,
966+
rigid_adjoint_cache=self._rigid_adjoint_cache,
972967
static_rigid_sim_config=self._static_rigid_sim_config,
973968
contact_island_state=self.constraint_solver.contact_island.contact_island_state,
974969
is_backward=False,
975970
)
976971

977-
if self._requires_grad:
978-
kernel_save_adjoint_cache(
979-
f=f + 1,
980-
dofs_state=self.dofs_state,
981-
rigid_global_info=self._rigid_global_info,
982-
rigid_adjoint_cache=self._rigid_adjoint_cache,
983-
static_rigid_sim_config=self._static_rigid_sim_config,
984-
)
985-
986972
def check_errno(self):
987973
# Note that errno must be evaluated BEFORE match because otherwise it will be evaluated for each case...
988974
# See official documentation: https://docs.python.org/3.10/reference/compound_stmts.html#overview
@@ -1245,8 +1231,6 @@ def substep_pre_coupling(self, f):
12451231
self.substep(f)
12461232

12471233
def substep_pre_coupling_grad(self, f):
1248-
curr_dofs_acc = self._rigid_adjoint_cache.dofs_acc.to_numpy()[f]
1249-
12501234
# Run forward substep again to restore this step's information, this is needed because we do not store info
12511235
# of every substep.
12521236
kernel_prepare_backward_substep(
@@ -1311,6 +1295,7 @@ def substep_pre_coupling_grad(self, f):
13111295
gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")
13121296

13131297
kernel_step_2.grad(
1298+
f=f,
13141299
dofs_state=self.dofs_state,
13151300
dofs_info=self.dofs_info,
13161301
links_info=self.links_info,
@@ -1323,21 +1308,34 @@ def substep_pre_coupling_grad(self, f):
13231308
geoms_state=self.geoms_state,
13241309
collider_state=self.collider._collider_state,
13251310
rigid_global_info=self._rigid_global_info,
1311+
rigid_adjoint_cache=self._rigid_adjoint_cache,
13261312
static_rigid_sim_config=self._static_rigid_sim_config,
13271313
contact_island_state=self.constraint_solver.contact_island.contact_island_state,
13281314
is_backward=True,
13291315
)
13301316

1317+
# We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel,
1318+
# which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules).
1319+
# In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc].
1320+
# As [kernel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through
1321+
# [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation
1322+
# cannot be merged with [kernel_compute_qacc.grad] because the .grad function is itself a standalone kernel.
1323+
# We could possibly merge this small kernel later if (1) the .grad function is regarded as a function instead of a
1324+
# kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid
1325+
# the data access violation. However, both of these require major changes.
13311326
kernel_compute_qacc.grad(
13321327
dofs_state=self.dofs_state,
13331328
entities_info=self.entities_info,
13341329
rigid_global_info=self._rigid_global_info,
13351330
static_rigid_sim_config=self._static_rigid_sim_config,
13361331
is_backward=True,
13371332
)
1338-
1339-
# Load the current dofs_acc from adjoint cache, as it was overwritten by [kernel_compute_qacc]
1340-
self.dofs_state.acc.from_numpy(curr_dofs_acc)
1333+
kernel_copy_acc(
1334+
f=f,
1335+
dofs_state=self.dofs_state,
1336+
rigid_adjoint_cache=self._rigid_adjoint_cache,
1337+
static_rigid_sim_config=self._static_rigid_sim_config,
1338+
)
13411339

13421340
kernel_forward_dynamics_without_qacc.grad(
13431341
links_state=self.links_state,
@@ -1386,6 +1384,7 @@ def substep_post_coupling(self, f):
13861384
is_backward=False,
13871385
)
13881386
kernel_step_2(
1387+
f=f,
13891388
dofs_state=self.dofs_state,
13901389
dofs_info=self.dofs_info,
13911390
links_info=self.links_info,
@@ -1398,18 +1397,11 @@ def substep_post_coupling(self, f):
13981397
geoms_state=self.geoms_state,
13991398
collider_state=self.collider._collider_state,
14001399
rigid_global_info=self._rigid_global_info,
1400+
rigid_adjoint_cache=self._rigid_adjoint_cache,
14011401
static_rigid_sim_config=self._static_rigid_sim_config,
14021402
contact_island_state=self.constraint_solver.contact_island.contact_island_state,
14031403
is_backward=False,
14041404
)
1405-
if self._requires_grad:
1406-
kernel_save_adjoint_cache(
1407-
f=f + 1,
1408-
dofs_state=self.dofs_state,
1409-
rigid_global_info=self._rigid_global_info,
1410-
rigid_adjoint_cache=self._rigid_adjoint_cache,
1411-
static_rigid_sim_config=self._static_rigid_sim_config,
1412-
)
14131405
elif isinstance(self.sim.coupler, IPCCoupler):
14141406
# For IPCCoupler, perform full rigid body computation in post-coupling phase
14151407
# This allows IPC to handle rigid bodies during the coupling phase
@@ -4574,6 +4566,7 @@ def func_update_cartesian_space(
45744566

45754567
@ti.kernel(fastcache=gs.use_fastcache)
45764568
def kernel_step_1(
4569+
f: ti.int32,
45774570
links_state: array_class.LinksState,
45784571
links_info: array_class.LinksInfo,
45794572
joints_state: array_class.JointsState,
@@ -4585,10 +4578,21 @@ def kernel_step_1(
45854578
entities_state: array_class.EntitiesState,
45864579
entities_info: array_class.EntitiesInfo,
45874580
rigid_global_info: array_class.RigidGlobalInfo,
4581+
rigid_adjoint_cache: array_class.RigidAdjointCache,
45884582
static_rigid_sim_config: ti.template(),
45894583
contact_island_state: array_class.ContactIslandState,
45904584
is_backward: ti.template(),
45914585
):
4586+
if ti.static(static_rigid_sim_config.requires_grad):
4587+
if f == 0:
4588+
func_save_adjoint_cache(
4589+
f=f,
4590+
dofs_state=dofs_state,
4591+
rigid_global_info=rigid_global_info,
4592+
rigid_adjoint_cache=rigid_adjoint_cache,
4593+
static_rigid_sim_config=static_rigid_sim_config,
4594+
)
4595+
45924596
if ti.static(static_rigid_sim_config.enable_mujoco_compatibility):
45934597
_B = links_state.pos.shape[1]
45944598
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
@@ -4700,6 +4704,7 @@ def func_implicit_damping(
47004704

47014705
@ti.kernel(fastcache=gs.use_fastcache)
47024706
def kernel_step_2(
4707+
f: ti.int32,
47034708
dofs_state: array_class.DofsState,
47044709
dofs_info: array_class.DofsInfo,
47054710
links_info: array_class.LinksInfo,
@@ -4712,6 +4717,7 @@ def kernel_step_2(
47124717
geoms_state: array_class.GeomsState,
47134718
collider_state: array_class.ColliderState,
47144719
rigid_global_info: array_class.RigidGlobalInfo,
4720+
rigid_adjoint_cache: array_class.RigidAdjointCache,
47154721
static_rigid_sim_config: ti.template(),
47164722
contact_island_state: array_class.ContactIslandState,
47174723
is_backward: ti.template(),
@@ -4799,6 +4805,15 @@ def kernel_step_2(
47994805
is_backward=is_backward,
48004806
)
48014807

4808+
if ti.static(static_rigid_sim_config.requires_grad):
4809+
func_save_adjoint_cache(
4810+
f=f + 1,
4811+
dofs_state=dofs_state,
4812+
rigid_global_info=rigid_global_info,
4813+
rigid_adjoint_cache=rigid_adjoint_cache,
4814+
static_rigid_sim_config=static_rigid_sim_config,
4815+
)
4816+
48024817

48034818
@ti.kernel(fastcache=gs.use_fastcache)
48044819
def kernel_forward_kinematics_links_geoms(
@@ -6753,19 +6768,6 @@ def func_integrate(
67536768
)
67546769

67556770

6756-
@ti.kernel(fastcache=gs.use_fastcache)
6757-
def kernel_copy_next_to_curr(
6758-
dofs_state: array_class.DofsState,
6759-
rigid_global_info: array_class.RigidGlobalInfo,
6760-
static_rigid_sim_config: ti.template(),
6761-
):
6762-
func_copy_next_to_curr(
6763-
dofs_state=dofs_state,
6764-
rigid_global_info=rigid_global_info,
6765-
static_rigid_sim_config=static_rigid_sim_config,
6766-
)
6767-
6768-
67696771
@ti.func
67706772
def func_copy_next_to_curr(
67716773
dofs_state: array_class.DofsState,
@@ -6806,8 +6808,8 @@ def func_copy_next_to_curr_grad(
68066808
rigid_global_info.qpos[i_q, i_b] = rigid_adjoint_cache.qpos[f, i_q, i_b]
68076809

68086810

6809-
@ti.kernel(fastcache=gs.use_fastcache)
6810-
def kernel_save_adjoint_cache(
6811+
@ti.func
6812+
def func_save_adjoint_cache(
68116813
f: ti.int32,
68126814
dofs_state: array_class.DofsState,
68136815
rigid_global_info: array_class.RigidGlobalInfo,
@@ -7048,28 +7050,18 @@ def func_copy_cartesian_space(
70487050

70497051

70507052
@ti.kernel(fastcache=gs.use_fastcache)
7051-
def kernel_copy_cartesian_space(
7052-
src_dofs_state: array_class.DofsState,
7053-
src_links_state: array_class.LinksState,
7054-
src_joints_state: array_class.JointsState,
7055-
src_geoms_state: array_class.GeomsState,
7056-
dst_dofs_state: array_class.DofsState,
7057-
dst_links_state: array_class.LinksState,
7058-
dst_joints_state: array_class.JointsState,
7059-
dst_geoms_state: array_class.GeomsState,
7053+
def kernel_copy_acc(
7054+
f: ti.int32,
7055+
dofs_state: array_class.DofsState,
7056+
rigid_adjoint_cache: array_class.RigidAdjointCache,
70607057
static_rigid_sim_config: ti.template(),
70617058
):
7062-
func_copy_cartesian_space(
7063-
src_dofs_state=src_dofs_state,
7064-
src_links_state=src_links_state,
7065-
src_joints_state=src_joints_state,
7066-
src_geoms_state=src_geoms_state,
7067-
dst_dofs_state=dst_dofs_state,
7068-
dst_links_state=dst_links_state,
7069-
dst_joints_state=dst_joints_state,
7070-
dst_geoms_state=dst_geoms_state,
7071-
static_rigid_sim_config=static_rigid_sim_config,
7072-
)
7059+
n_dofs = dofs_state.vel.shape[0]
7060+
_B = dofs_state.vel.shape[1]
7061+
7062+
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
7063+
for i_d, i_b in ti.ndrange(n_dofs, _B):
7064+
dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b]
70737065

70747066

70757067
@ti.func

tests/test_grad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def constraint_solver_resolve():
193193
# Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the
194194
# entire scene.step() because it will overwrite the necessary information that we need to compute the gradients.
195195
kernel_step_1(
196+
f=0,
196197
links_state=rigid_solver.links_state,
197198
links_info=rigid_solver.links_info,
198199
joints_state=rigid_solver.joints_state,
@@ -204,6 +205,7 @@ def constraint_solver_resolve():
204205
entities_state=rigid_solver.entities_state,
205206
entities_info=rigid_solver.entities_info,
206207
rigid_global_info=rigid_solver._rigid_global_info,
208+
rigid_adjoint_cache=rigid_solver._rigid_adjoint_cache,
207209
static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
208210
contact_island_state=constraint_solver.contact_island.contact_island_state,
209211
is_backward=False,

0 commit comments

Comments
 (0)