[Pallas] Use scratch_shapes for scratch operands in flash attention kernel.

PiperOrigin-RevId: 611884935
bythew3i authored and jax authors committed Mar 2, 2024
1 parent 28f84eb commit 51a31e5
Showing 1 changed file with 39 additions and 39 deletions.
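For context, the pattern this commit adopts is Pallas TPU's scratch_shapes: scratch buffers are declared on pltpu.PrefetchScalarGridSpec (e.g. as pltpu.VMEM(shape, dtype)) and arrive in the kernel as trailing ref arguments after the output refs, instead of being threaded through out_shape/out_specs as dummy outputs. Below is a minimal sketch of that pattern, not part of this commit; double_kernel, double, and the assumption of a small float32 input (e.g. shape (8, 128)) are illustrative.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def double_kernel(x_ref, o_ref, acc_scratch_ref):
      # Refs arrive in order: inputs, then outputs, then the scratch buffers
      # declared via scratch_shapes. The scratch ref no longer occupies a slot
      # in out_shape/out_specs.
      acc_scratch_ref[...] = x_ref[...] * 2.0
      o_ref[...] = acc_scratch_ref[...]

    def double(x):  # x: float32 array small enough to fit in VMEM
      return pl.pallas_call(
          double_kernel,
          grid_spec=pltpu.PrefetchScalarGridSpec(
              num_scalar_prefetch=0,  # no scalar-prefetch operands here
              grid=(1,),
              in_specs=[pl.BlockSpec(lambda i: (0, 0), x.shape)],
              out_specs=pl.BlockSpec(lambda i: (0, 0), x.shape),
              scratch_shapes=[pltpu.VMEM(x.shape, jnp.float32)],  # VMEM scratch
          ),
          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      )(x)

A practical consequence, visible in the diff below: callers stop unpacking placeholder outputs, e.g. "dk, dv, _, _ = pl.pallas_call(...)" becomes "dk, dv = pl.pallas_call(...)" in the dkv and dq backward passes.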
78 changes: 39 additions & 39 deletions jax/experimental/pallas/ops/tpu/flash_attention.py
@@ -347,11 +347,11 @@ def _flash_attention_kernel_single_batch(
     q_segment_ids_tile_ref,
     kv_segment_ids_tile_ref,  # Input arrays
     o_tile_ref,  # Output arrays
+    l_ref,
+    m_ref,
     m_scratch_ref,
     l_scratch_ref,
     acc_scratch_ref,
-    l_ref: Any | None = None,
-    m_ref: Any | None = None,
     *,
     causal,
     sm_scale,
@@ -496,9 +496,6 @@ def _flash_attention_kernel_single_batch_single_step(
     q_segment_ids_tile_ref,
     kv_segment_ids_tile_ref,  # Input arrays
     o_tile_ref,  # Output arrays
-    m_scratch_ref,
-    l_scratch_ref,
-    acc_scratch_ref,
     l_ref: Any | None = None,
     m_ref: Any | None = None,
     *,
@@ -511,8 +508,6 @@ def _flash_attention_kernel_single_batch_single_step(
   block_k_major = k_tile_ref.shape[2]
   block_q = q_tile_ref.shape[2]
 
-  scratch_refs = (m_scratch_ref, l_scratch_ref, acc_scratch_ref)
-  assert all(ref is None for ref in scratch_refs)
   assert kv_seq_len == block_k_major == block_k
 
   q = q_tile_ref[batch_idx]  # [block_q, head_dim]
@@ -656,19 +651,12 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
   out_specs = [pl.BlockSpec(o_index_map, (block_b, 1, block_q, head_dim))]
 
   if block_k != kv_seq_len:
-    scratch_shape = functools.partial(jax.ShapeDtypeStruct, dtype=jnp.float32)
-    m_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
-    l_scratch = scratch_shape((block_b, 1, block_q, MIN_BLOCK_SIZE))
-    acc_scratch = scratch_shape((block_b, 1, block_q, head_dim))
-    out_shape += [m_scratch, l_scratch, acc_scratch]
-    out_specs += [
-        pl.BlockSpec(lambda *_: (0, 0, 0, 0), m_scratch.shape),
-        pl.BlockSpec(lambda *_: (0, 0, 0, 0), l_scratch.shape),
-        pl.BlockSpec(lambda *_: (0, 0, 0, 0), acc_scratch.shape),
-    ]
+    m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
+    l_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
+    acc_scratch = pltpu.VMEM((block_b, 1, block_q, head_dim), jnp.float32)
+    scratch_shapes = [m_scratch, l_scratch, acc_scratch]
   else:
-    out_shape += [None, None, None]
-    out_specs += [None, None, None]
+    scratch_shapes = []
 
   if save_residuals:
     out_specs = [
@@ -683,6 +671,9 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
         (batch_size, num_heads, q_seq_len, MIN_BLOCK_SIZE), dtype=jnp.float32
     )
     out_shape = (*out_shape, l, m)
+  else:
+    out_specs = [*out_specs, None, None]
+    out_shape = (*out_shape, None, None)
 
   ab_block_spec = (
       pl.BlockSpec(ab_index_map, (block_b, 1, block_q, block_k_major))
@@ -745,10 +736,14 @@ def kv_segment_ids_index_map(
 
   o, *aux = pl.pallas_call(
       kernel,
+      grid_spec=pltpu.PrefetchScalarGridSpec(
+          num_scalar_prefetch=0,
+          grid=grid,
+          in_specs=in_specs,
+          out_specs=out_specs,
+          scratch_shapes=scratch_shapes,
+      ),
       out_shape=out_shape,
-      in_specs=in_specs,
-      out_specs=out_specs,
-      grid=grid,
       debug=debug,
       mosaic_params=dict(
           dimension_semantics=("parallel", "parallel", "parallel", "arbitrary")
@@ -1070,17 +1065,15 @@ def kv_segment_ids_index_map(batch_index, head_index, kv_seq_index, _):
                            k.dtype),
       jax.ShapeDtypeStruct((batch_size, num_heads, kv_seq_len, head_dim),
                            v.dtype),
-      jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32),
-      jax.ShapeDtypeStruct((block_k_major, head_dim), jnp.float32),
   ]
   def dkv_index_map(batch_index, head_index, kv_seq_index, _):
     return (batch_index, head_index, kv_seq_index, 0)
 
   dkv_spec = pl.BlockSpec(dkv_index_map, (1, 1, block_k_major, head_dim))
-  out_specs = [
-      dkv_spec, dkv_spec,
-      pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)),
-      pl.BlockSpec(lambda *_: (0, 0), (block_k_major, head_dim)),
+  out_specs = [dkv_spec, dkv_spec]
+  scratch_shapes = [
+      pltpu.VMEM((block_k_major, head_dim), jnp.float32),  # type: ignore
+      pltpu.VMEM((block_k_major, head_dim), jnp.float32),  # type: ignore
   ]
 
   kernel = functools.partial(
@@ -1094,12 +1087,16 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
   )
   name_scope = f"flash_mha_bwd_dkv_{block_q_major=}_{block_q=}_{block_k_major=}_{block_k=}"
   with jax.named_scope(name_scope):
-    dk, dv, _, _ = pl.pallas_call(
+    dk, dv = pl.pallas_call(
         kernel,
-        in_specs=in_specs,  # type: ignore
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            grid=grid,
+            in_specs=in_specs,  # type: ignore
+            out_specs=out_specs,
+            scratch_shapes=scratch_shapes,
+        ),
         out_shape=out_shapes,
-        out_specs=out_specs,
-        grid=grid,
         debug=debug,
         mosaic_params=dict(
            dimension_semantics=(
@@ -1127,8 +1124,8 @@ def _flash_attention_dq_kernel(
     do_tile_ref,
     di_tile_ref,
     dq_tile_ref,
-    dq_scratch_ref,
     ds_tile_ref,
+    dq_scratch_ref,
     *,
     sm_scale: float,
     causal: bool,
@@ -1414,15 +1411,14 @@ def kv_segment_ids_index_map(
 
   out_shapes = [
       jax.ShapeDtypeStruct(q.shape, q.dtype),
-      jax.ShapeDtypeStruct((block_q_major, head_dim), jnp.float32),
       jax.ShapeDtypeStruct(ab.shape, ab.dtype) if ab is not None else None,
   ]
   dq_spec = pl.BlockSpec(qo_index_map, (1, 1, block_q_major, head_dim))
   out_specs = [
       dq_spec,
-      pl.BlockSpec(lambda *_: (0, 0), (block_q_major, head_dim)),
       dab_spec,
   ]
+  scratch_shapes = [pltpu.VMEM((block_q_major, head_dim), jnp.float32)]  # type: ignore
 
   kernel = functools.partial(
       _flash_attention_dq_kernel,
@@ -1434,12 +1430,16 @@ def kv_segment_ids_index_map(
   )
   name_scope = f"flash_mha_bwd_dq_{block_q_major=}_{block_k_major=}_{block_k=}"
   with jax.named_scope(name_scope):
-    dq, _, ds = pl.pallas_call(
+    dq, ds = pl.pallas_call(
        kernel,
-        in_specs=in_specs,  # type: ignore
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            grid=grid,
+            in_specs=in_specs,  # type: ignore
+            out_specs=out_specs,  # type: ignore
+            scratch_shapes=scratch_shapes,
+        ),
        out_shape=out_shapes,
-        out_specs=out_specs,  # type: ignore
-        grid=grid,
        debug=debug,
        mosaic_params=dict(
            dimension_semantics=(
