@@ -1784,6 +1784,9 @@ def partition(config, mesh, arg_infos, result_infos):
17841784 )
17851785 arg_shardings = [arg_i .sharding for arg_i in arg_infos ]
17861786 arg_shardings [4 ] = seed_sharding
1787+ # Ensure segment_pos gets same sharding as ID.
1788+ arg_shardings [- 1 ] = arg_shardings [- 3 ]
1789+ arg_shardings [- 2 ] = arg_shardings [- 4 ]
17871790 arg_shardings = tuple (arg_shardings )
17881791 out_shardings = (out_sharding , softmax_aux_sharding , rng_state_sharding )
17891792
@@ -1991,7 +1994,13 @@ def partition(config, mesh, arg_infos, result_infos):
19911994 dk_sharding = NamedSharding (mesh , PartitionSpec (* k_spec ))
19921995 dv_sharding = NamedSharding (mesh , PartitionSpec (* v_spec ))
19931996 dbias_sharding = NamedSharding (mesh , PartitionSpec (* bias_spec ))
1994- arg_shardings = tuple (arg_i .sharding for arg_i in arg_infos )
1997+
1998+ arg_shardings = [arg_i .sharding for arg_i in arg_infos ]
1999+ # Ensure segment_pos gets same sharding as ID.
2000+ arg_shardings [- 1 ] = arg_shardings [- 3 ]
2001+ arg_shardings [- 2 ] = arg_shardings [- 4 ]
2002+ arg_shardings = tuple (arg_shardings )
2003+
19952004 out_shardings = (dq_sharding , dk_sharding , dv_sharding , dbias_sharding )
19962005
19972006 helper = _FusedAttnCPWithP2PHelper (mesh , config )
@@ -2265,6 +2274,9 @@ def partition(config, mesh, arg_infos, result_infos):
22652274 )
22662275 arg_shardings = [arg_i .sharding for arg_i in arg_infos ]
22672276 arg_shardings [4 ] = seed_sharding
2277+ # Ensure segment_pos gets same sharding as ID.
2278+ arg_shardings [- 1 ] = arg_shardings [- 3 ]
2279+ arg_shardings [- 2 ] = arg_shardings [- 4 ]
22682280 arg_shardings = tuple (arg_shardings )
22692281 out_shardings = (out_sharding , softmax_aux_sharding , rng_state_sharding )
22702282
@@ -2403,7 +2415,11 @@ def partition(config, mesh, arg_infos, result_infos):
24032415 if not is_context_parallel :
24042416 return FusedAttnBwdPrimitive .partition (config , mesh , arg_infos , result_infos )
24052417
2406- arg_shardings = tuple (arg .sharding for arg in arg_infos )
2418+ arg_shardings = [arg_i .sharding for arg_i in arg_infos ]
2419+ # Ensure segment_pos gets same sharding as ID.
2420+ arg_shardings [- 1 ] = arg_shardings [- 3 ]
2421+ arg_shardings [- 2 ] = arg_shardings [- 4 ]
2422+ arg_shardings = tuple (arg_shardings )
24072423 # dq, dk, dv, dbias sharding = q, k, v, bias sharding
24082424 out_shardings = tuple (arg .sharding for arg in arg_infos [:4 ])
24092425
0 commit comments