[Pallas/TPU] Fix bug with LocalMask grid shrinking
LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space.
As a consequence, the kernel body must use `global_kv_index`: the kv index in the "global" space, i.e. before any shrinking of the iteration space.

PiperOrigin-RevId: 655901432
Google-ML-Automation authored and jax authors committed Jul 25, 2024
1 parent e14752c commit f15f971
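
To make the failure mode concrete, here is a minimal, self-contained sketch of the indexing problem the fix addresses. The names `kv_blocks_to_visit` and the block size are illustrative stand-ins for the MaskInfo/`data_next` machinery, not the kernel's actual data structures.

# Illustrative sketch only: hypothetical stand-ins showing why the shrunk
# loop index and the global kv index differ under a local mask.
import numpy as np

bkv = 128                                  # kv block size (example value)
num_kv_blocks = 8                          # unshrunk grid: j would run over 0..7

# Under a local mask, most kv blocks are fully masked out, so the grid is
# shrunk to visit only the non-empty blocks, e.g. blocks 5, 6 and 7.
kv_blocks_to_visit = np.array([5, 6, 7])   # what data_next-style metadata encodes

for j, global_kv_index in enumerate(kv_blocks_to_visit):
    buggy_k_offset = j * bkv                   # 0, 128, 256: wrong coordinates
    fixed_k_offset = global_kv_index * bkv     # 640, 768, 896: real KV position
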
Showing 2 changed files with 78 additions and 4 deletions.
16 changes: 12 additions & 4 deletions jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
@@ -725,7 +725,7 @@ def init():
m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value)
l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref)

- _, _, should_run, should_not_mask = _next_nonzero(
+ global_kv_index, _, should_run, should_not_mask = _next_nonzero(
h,
i,
j,
@@ -760,7 +760,11 @@ def body(kv_compute_index, _):
kv_segment_ids_ref,
attn_logits_soft_cap=attn_logits_soft_cap,
k_slice=slice_k,
- k_offset=j * bkv + kv_compute_index * bkv_compute,
+ # When the iteration space is shrunk (for local attention for example),
+ # the kv_index program_id does not correspond to the actual coordinates
+ # of the KV data. Make sure to use the 'unshrunk' index (coming from the
+ # data_next array) when computing the mask.
+ k_offset=global_kv_index * bkv + kv_compute_index * bkv_compute,
bq=bq,
mask_function=mask_function,
)
@@ -1282,7 +1286,7 @@ def _flash_attention_dq_kernel(
def init():
dq_scratch_ref[...] = jnp.zeros_like(dq_scratch_ref)

- _, _, should_run, should_not_mask = _next_nonzero(
+ global_kv_index, _, should_run, should_not_mask = _next_nonzero(
h, i, j, data_next_ref, block_mask_ref, mask_next_ref
)
@pl.when(should_run)
@@ -1308,7 +1312,11 @@ def run():
kv_segment_ids_ref,
attn_logits_soft_cap=attn_logits_soft_cap,
k_slice=pl.ds(0, bkv),
- k_offset=j * bkv,
+ # When the iteration space is shrunk (for local attention for example),
+ # the kv_index program_id does not correspond to the actual coordinates
+ # of the KV data. Make sure to use the 'unshrunk' index (coming from the
+ # data_next array) when computing the mask.
+ k_offset=global_kv_index * bkv,
bq=bq,
mask_function=mask_function,
)
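
For context on why the offset matters, the sketch below shows how a `k_offset` and the block sizes can translate into the coordinates a computable mask function sees. `tile_mask` and the causal lambda are hypothetical helpers for illustration, not the kernel's internals.

# Hypothetical helper, for illustration only: it mimics how a k_offset and the
# block sizes determine the (q_ids, kv_ids) coordinates a mask_function sees.
import jax.numpy as jnp

def tile_mask(q_block_index, k_offset, bq, bkv_compute, mask_function):
  q_ids = q_block_index * bq + jnp.arange(bq)[:, None]     # global query rows
  kv_ids = k_offset + jnp.arange(bkv_compute)[None, :]     # global kv columns
  return mask_function(q_ids, kv_ids)

# If k_offset is derived from the shrunk program_id j instead of
# global_kv_index, kv_ids point at the wrong columns and the mask is evaluated
# in the wrong place, which is the bug this commit fixes.
causal = lambda q_ids, kv_ids: q_ids >= kv_ids
print(tile_mask(q_block_index=2, k_offset=8, bq=4, bkv_compute=4,
                mask_function=causal))
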
66 changes: 66 additions & 0 deletions tests/pallas/tpu_splash_attention_kernel_test.py
@@ -28,6 +28,7 @@
from jax._src import test_util as jtu
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib
+ from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info import process_mask
import jax.numpy as jnp
import numpy as np

@@ -649,6 +650,71 @@ def bwd(
self._assert_allclose(dq, dq_ref, atol=2e-2, rtol=3e-2)
self._assert_allclose(dk, dk_ref, atol=2e-2, rtol=3e-2)

def test_grid_shrinking(self):
"""Make sure that grid shrinking does not change the attention output."""

class IdentityMask(mask_lib._ComputableMask):
"""Identity mask that is guaranteed to trigger grid shrinking."""

def __init__(
self,
shape: tuple[int, int],
shard_count: int = 1,
):
def identity_mask_function(q_ids, kv_ids):
return q_ids == kv_ids

super().__init__(
shape=shape,
mask_function=identity_mask_function,
shard_count=shard_count,
)

def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented

return self.shape == other.shape and np.array_equal(
self.q_sequence, other.q_sequence
)

def __hash__(self):
return hash((
type(self),
self.shape,
self.q_sequence.tobytes() if self.q_sequence is not None else None,
))

# Use a sequence length greater than the default block size to trigger
# the grid shrinking logic.
seq_len = 256
head_dim = 128
key = random.key(42)
k1, k2, k3 = random.split(key, 3)
q = random.uniform(k1, (1, seq_len, head_dim), dtype=jnp.float32)
k = random.uniform(k2, (seq_len, head_dim), dtype=jnp.float32)
v = random.uniform(k3, (seq_len, head_dim), dtype=jnp.float32)

identity_mask = mask_lib.MultiHeadMask([IdentityMask((seq_len, seq_len))])

process_mask_path = "jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info.process_mask"
process_mask_shrink = lambda *args, **kwargs: process_mask(
*args, **kwargs, shrink_grid=True
)
process_mask_no_shrink = lambda *args, **kwargs: process_mask(
*args, **kwargs, shrink_grid=False
)

with unittest.mock.patch(process_mask_path, process_mask_shrink):
shrink_out = splash.make_splash_mqa_single_device(identity_mask)(q, k, v)

with unittest.mock.patch(process_mask_path, process_mask_no_shrink):
no_shrink_out = splash.make_splash_mqa_single_device(identity_mask)(
q, k, v
)

np.testing.assert_array_equal(shrink_out, no_shrink_out)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
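
As a closing usage sketch, this is how the fixed path could be exercised end to end with the mask type that actually triggers grid shrinking. Treat the LocalMask arguments as assumptions: the (left, right) `window_size` and `offset` reflect a reading of splash_attention_mask around this commit, and the shapes are arbitrary example values.

# Usage sketch (assumptions flagged above); runs on a TPU device.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel as splash,
)
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_mask as mask_lib,
)

seq_len, head_dim = 1024, 128
local_mask = mask_lib.MultiHeadMask(
    [mask_lib.LocalMask((seq_len, seq_len), window_size=(128, 128), offset=0)]
)
attn = splash.make_splash_mqa_single_device(local_mask)

q = jnp.ones((1, seq_len, head_dim), jnp.float32)   # (heads, seq, head_dim)
k = jnp.ones((seq_len, head_dim), jnp.float32)
v = jnp.ones((seq_len, head_dim), jnp.float32)
out = attn(q, k, v)  # with the fix, shrinking the grid leaves `out` unchanged
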
