Skip to content

Commit f62cad9

Browse files
Fix sharding of segment position to match id in ring attention. (#2349)
1 parent 4ff3eed commit f62cad9

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,6 +1784,9 @@ def partition(config, mesh, arg_infos, result_infos):
17841784
)
17851785
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
17861786
arg_shardings[4] = seed_sharding
1787+
# Ensure segment_pos gets same sharding as ID.
1788+
arg_shardings[-1] = arg_shardings[-3]
1789+
arg_shardings[-2] = arg_shardings[-4]
17871790
arg_shardings = tuple(arg_shardings)
17881791
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
17891792

@@ -1991,7 +1994,13 @@ def partition(config, mesh, arg_infos, result_infos):
19911994
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
19921995
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
19931996
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
1994-
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
1997+
1998+
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
1999+
# Ensure segment_pos gets same sharding as ID.
2000+
arg_shardings[-1] = arg_shardings[-3]
2001+
arg_shardings[-2] = arg_shardings[-4]
2002+
arg_shardings = tuple(arg_shardings)
2003+
19952004
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
19962005

19972006
helper = _FusedAttnCPWithP2PHelper(mesh, config)
@@ -2265,6 +2274,9 @@ def partition(config, mesh, arg_infos, result_infos):
22652274
)
22662275
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
22672276
arg_shardings[4] = seed_sharding
2277+
# Ensure segment_pos gets same sharding as ID.
2278+
arg_shardings[-1] = arg_shardings[-3]
2279+
arg_shardings[-2] = arg_shardings[-4]
22682280
arg_shardings = tuple(arg_shardings)
22692281
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
22702282

@@ -2403,7 +2415,11 @@ def partition(config, mesh, arg_infos, result_infos):
24032415
if not is_context_parallel:
24042416
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
24052417

2406-
arg_shardings = tuple(arg.sharding for arg in arg_infos)
2418+
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
2419+
# Ensure segment_pos gets same sharding as ID.
2420+
arg_shardings[-1] = arg_shardings[-3]
2421+
arg_shardings[-2] = arg_shardings[-4]
2422+
arg_shardings = tuple(arg_shardings)
24072423
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
24082424
out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
24092425

0 commit comments

Comments
 (0)