@@ -921,16 +921,8 @@ def substep(self, f):
921921
922922 self ._links_state_cache .clear ()
923923
924- if f == 0 and self ._requires_grad :
925- kernel_save_adjoint_cache (
926- f = f ,
927- dofs_state = self .dofs_state ,
928- rigid_global_info = self ._rigid_global_info ,
929- rigid_adjoint_cache = self ._rigid_adjoint_cache ,
930- static_rigid_sim_config = self ._static_rigid_sim_config ,
931- )
932-
933924 kernel_step_1 (
925+ f = f ,
934926 links_state = self .links_state ,
935927 links_info = self .links_info ,
936928 joints_state = self .joints_state ,
@@ -942,6 +934,7 @@ def substep(self, f):
942934 entities_state = self .entities_state ,
943935 entities_info = self .entities_info ,
944936 rigid_global_info = self ._rigid_global_info ,
937+ rigid_adjoint_cache = self ._rigid_adjoint_cache ,
945938 static_rigid_sim_config = self ._static_rigid_sim_config ,
946939 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
947940 is_backward = False ,
@@ -957,6 +950,7 @@ def substep(self, f):
957950 else :
958951 self ._func_constraint_force ()
959952 kernel_step_2 (
953+ f = f ,
960954 dofs_state = self .dofs_state ,
961955 dofs_info = self .dofs_info ,
962956 links_info = self .links_info ,
@@ -969,20 +963,12 @@ def substep(self, f):
969963 geoms_state = self .geoms_state ,
970964 collider_state = self .collider ._collider_state ,
971965 rigid_global_info = self ._rigid_global_info ,
966+ rigid_adjoint_cache = self ._rigid_adjoint_cache ,
972967 static_rigid_sim_config = self ._static_rigid_sim_config ,
973968 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
974969 is_backward = False ,
975970 )
976971
977- if self ._requires_grad :
978- kernel_save_adjoint_cache (
979- f = f + 1 ,
980- dofs_state = self .dofs_state ,
981- rigid_global_info = self ._rigid_global_info ,
982- rigid_adjoint_cache = self ._rigid_adjoint_cache ,
983- static_rigid_sim_config = self ._static_rigid_sim_config ,
984- )
985-
986972 def check_errno (self ):
987973 # Note that errno must be evaluated BEFORE match because otherwise it will be evaluated for each case...
988974 # See official documentation: https://docs.python.org/3.10/reference/compound_stmts.html#overview
@@ -1245,8 +1231,6 @@ def substep_pre_coupling(self, f):
12451231 self .substep (f )
12461232
12471233 def substep_pre_coupling_grad (self , f ):
1248- curr_dofs_acc = self ._rigid_adjoint_cache .dofs_acc .to_numpy ()[f ]
1249-
12501234 # Run forward substep again to restore this step's information, this is needed because we do not store info
12511235 # of every substep.
12521236 kernel_prepare_backward_substep (
@@ -1311,6 +1295,7 @@ def substep_pre_coupling_grad(self, f):
13111295 gs .raise_exception (f"Nan grad in qpos or dofs_vel found at step { self ._sim .cur_step_global } " )
13121296
13131297 kernel_step_2 .grad (
1298+ f = f ,
13141299 dofs_state = self .dofs_state ,
13151300 dofs_info = self .dofs_info ,
13161301 links_info = self .links_info ,
@@ -1323,21 +1308,34 @@ def substep_pre_coupling_grad(self, f):
13231308 geoms_state = self .geoms_state ,
13241309 collider_state = self .collider ._collider_state ,
13251310 rigid_global_info = self ._rigid_global_info ,
1311+ rigid_adjoint_cache = self ._rigid_adjoint_cache ,
13261312 static_rigid_sim_config = self ._static_rigid_sim_config ,
13271313 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
13281314 is_backward = True ,
13291315 )
13301316
1317+ # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel,
1318+ # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules).
1319+ # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc].
1320+ # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through
1321+ # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation
1322+ # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel.
1323+ # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a
1324+ # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid
1325+ # the data access violation. However, both of these require major changes.
13311326 kernel_compute_qacc .grad (
13321327 dofs_state = self .dofs_state ,
13331328 entities_info = self .entities_info ,
13341329 rigid_global_info = self ._rigid_global_info ,
13351330 static_rigid_sim_config = self ._static_rigid_sim_config ,
13361331 is_backward = True ,
13371332 )
1338-
1339- # Load the current dofs_acc from adjoint cache, as it was overwritten by [kernel_compute_qacc]
1340- self .dofs_state .acc .from_numpy (curr_dofs_acc )
1333+ kernel_copy_acc (
1334+ f = f ,
1335+ dofs_state = self .dofs_state ,
1336+ rigid_adjoint_cache = self ._rigid_adjoint_cache ,
1337+ static_rigid_sim_config = self ._static_rigid_sim_config ,
1338+ )
13411339
13421340 kernel_forward_dynamics_without_qacc .grad (
13431341 links_state = self .links_state ,
@@ -1386,6 +1384,7 @@ def substep_post_coupling(self, f):
13861384 is_backward = False ,
13871385 )
13881386 kernel_step_2 (
1387+ f = f ,
13891388 dofs_state = self .dofs_state ,
13901389 dofs_info = self .dofs_info ,
13911390 links_info = self .links_info ,
@@ -1398,18 +1397,11 @@ def substep_post_coupling(self, f):
13981397 geoms_state = self .geoms_state ,
13991398 collider_state = self .collider ._collider_state ,
14001399 rigid_global_info = self ._rigid_global_info ,
1400+ rigid_adjoint_cache = self ._rigid_adjoint_cache ,
14011401 static_rigid_sim_config = self ._static_rigid_sim_config ,
14021402 contact_island_state = self .constraint_solver .contact_island .contact_island_state ,
14031403 is_backward = False ,
14041404 )
1405- if self ._requires_grad :
1406- kernel_save_adjoint_cache (
1407- f = f + 1 ,
1408- dofs_state = self .dofs_state ,
1409- rigid_global_info = self ._rigid_global_info ,
1410- rigid_adjoint_cache = self ._rigid_adjoint_cache ,
1411- static_rigid_sim_config = self ._static_rigid_sim_config ,
1412- )
14131405 elif isinstance (self .sim .coupler , IPCCoupler ):
14141406 # For IPCCoupler, perform full rigid body computation in post-coupling phase
14151407 # This allows IPC to handle rigid bodies during the coupling phase
@@ -4574,6 +4566,7 @@ def func_update_cartesian_space(
45744566
45754567@ti .kernel (fastcache = gs .use_fastcache )
45764568def kernel_step_1 (
4569+ f : ti .int32 ,
45774570 links_state : array_class .LinksState ,
45784571 links_info : array_class .LinksInfo ,
45794572 joints_state : array_class .JointsState ,
@@ -4585,10 +4578,21 @@ def kernel_step_1(
45854578 entities_state : array_class .EntitiesState ,
45864579 entities_info : array_class .EntitiesInfo ,
45874580 rigid_global_info : array_class .RigidGlobalInfo ,
4581+ rigid_adjoint_cache : array_class .RigidAdjointCache ,
45884582 static_rigid_sim_config : ti .template (),
45894583 contact_island_state : array_class .ContactIslandState ,
45904584 is_backward : ti .template (),
45914585):
4586+ if ti .static (static_rigid_sim_config .requires_grad ):
4587+ if f == 0 :
4588+ func_save_adjoint_cache (
4589+ f = f ,
4590+ dofs_state = dofs_state ,
4591+ rigid_global_info = rigid_global_info ,
4592+ rigid_adjoint_cache = rigid_adjoint_cache ,
4593+ static_rigid_sim_config = static_rigid_sim_config ,
4594+ )
4595+
45924596 if ti .static (static_rigid_sim_config .enable_mujoco_compatibility ):
45934597 _B = links_state .pos .shape [1 ]
45944598 ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
@@ -4700,6 +4704,7 @@ def func_implicit_damping(
47004704
47014705@ti .kernel (fastcache = gs .use_fastcache )
47024706def kernel_step_2 (
4707+ f : ti .int32 ,
47034708 dofs_state : array_class .DofsState ,
47044709 dofs_info : array_class .DofsInfo ,
47054710 links_info : array_class .LinksInfo ,
@@ -4712,6 +4717,7 @@ def kernel_step_2(
47124717 geoms_state : array_class .GeomsState ,
47134718 collider_state : array_class .ColliderState ,
47144719 rigid_global_info : array_class .RigidGlobalInfo ,
4720+ rigid_adjoint_cache : array_class .RigidAdjointCache ,
47154721 static_rigid_sim_config : ti .template (),
47164722 contact_island_state : array_class .ContactIslandState ,
47174723 is_backward : ti .template (),
@@ -4799,6 +4805,15 @@ def kernel_step_2(
47994805 is_backward = is_backward ,
48004806 )
48014807
4808+ if ti .static (static_rigid_sim_config .requires_grad ):
4809+ func_save_adjoint_cache (
4810+ f = f + 1 ,
4811+ dofs_state = dofs_state ,
4812+ rigid_global_info = rigid_global_info ,
4813+ rigid_adjoint_cache = rigid_adjoint_cache ,
4814+ static_rigid_sim_config = static_rigid_sim_config ,
4815+ )
4816+
48024817
48034818@ti .kernel (fastcache = gs .use_fastcache )
48044819def kernel_forward_kinematics_links_geoms (
@@ -6753,19 +6768,6 @@ def func_integrate(
67536768 )
67546769
67556770
6756- @ti .kernel (fastcache = gs .use_fastcache )
6757- def kernel_copy_next_to_curr (
6758- dofs_state : array_class .DofsState ,
6759- rigid_global_info : array_class .RigidGlobalInfo ,
6760- static_rigid_sim_config : ti .template (),
6761- ):
6762- func_copy_next_to_curr (
6763- dofs_state = dofs_state ,
6764- rigid_global_info = rigid_global_info ,
6765- static_rigid_sim_config = static_rigid_sim_config ,
6766- )
6767-
6768-
67696771@ti .func
67706772def func_copy_next_to_curr (
67716773 dofs_state : array_class .DofsState ,
@@ -6806,8 +6808,8 @@ def func_copy_next_to_curr_grad(
68066808 rigid_global_info .qpos [i_q , i_b ] = rigid_adjoint_cache .qpos [f , i_q , i_b ]
68076809
68086810
6809- @ti .kernel ( fastcache = gs . use_fastcache )
6810- def kernel_save_adjoint_cache (
6811+ @ti .func
6812+ def func_save_adjoint_cache (
68116813 f : ti .int32 ,
68126814 dofs_state : array_class .DofsState ,
68136815 rigid_global_info : array_class .RigidGlobalInfo ,
@@ -7048,28 +7050,18 @@ def func_copy_cartesian_space(
70487050
70497051
70507052@ti .kernel (fastcache = gs .use_fastcache )
7051- def kernel_copy_cartesian_space (
7052- src_dofs_state : array_class .DofsState ,
7053- src_links_state : array_class .LinksState ,
7054- src_joints_state : array_class .JointsState ,
7055- src_geoms_state : array_class .GeomsState ,
7056- dst_dofs_state : array_class .DofsState ,
7057- dst_links_state : array_class .LinksState ,
7058- dst_joints_state : array_class .JointsState ,
7059- dst_geoms_state : array_class .GeomsState ,
7053+ def kernel_copy_acc (
7054+ f : ti .int32 ,
7055+ dofs_state : array_class .DofsState ,
7056+ rigid_adjoint_cache : array_class .RigidAdjointCache ,
70607057 static_rigid_sim_config : ti .template (),
70617058):
7062- func_copy_cartesian_space (
7063- src_dofs_state = src_dofs_state ,
7064- src_links_state = src_links_state ,
7065- src_joints_state = src_joints_state ,
7066- src_geoms_state = src_geoms_state ,
7067- dst_dofs_state = dst_dofs_state ,
7068- dst_links_state = dst_links_state ,
7069- dst_joints_state = dst_joints_state ,
7070- dst_geoms_state = dst_geoms_state ,
7071- static_rigid_sim_config = static_rigid_sim_config ,
7072- )
7059+ n_dofs = dofs_state .vel .shape [0 ]
7060+ _B = dofs_state .vel .shape [1 ]
7061+
7062+ ti .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL )
7063+ for i_d , i_b in ti .ndrange (n_dofs , _B ):
7064+ dofs_state .acc [i_d , i_b ] = rigid_adjoint_cache .dofs_acc [f , i_d , i_b ]
70737065
70747066
70757067@ti .func
0 commit comments