Removed redundant options like flash_attn
ardagoreci committed Feb 5, 2024
1 parent 65a41a0 commit c646298
Showing 4 changed files with 123 additions and 126 deletions.
206 changes: 94 additions & 112 deletions src/models/components/primitives.py
@@ -29,15 +29,12 @@
)

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
if deepspeed_is_installed:
import deepspeed

# fa_is_installed = importlib.util.find_spec("flash_attn") is not None
# if fa_is_installed:
# from flash_attn.bert_padding import unpad_input, pad_input
# from flash_attn.flash_attention import FlashAttention
# from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

if ds4s_is_installed:
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

DEFAULT_LMA_Q_CHUNK_SIZE = 1024
DEFAULT_LMA_KV_CHUNK_SIZE = 4096
@@ -193,9 +190,9 @@ def forward(self, x):
d = x.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()
)
if (d is torch.bfloat16 and not deepspeed_is_initialized):
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
out = nn.functional.layer_norm(
x,
@@ -225,9 +222,10 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
d = t.dtype
deepspeed_is_initialized = (
deepspeed_is_installed and
deepspeed.utils.is_initialized()
deepspeed.comm.comm.is_initialized()

)
if (d is torch.bfloat16 and not deepspeed_is_initialized):
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
@@ -259,7 +257,7 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, bias
def _attention_chunked_trainable(
query, key, value, biases, chunk_size, chunk_dim, checkpoint,
):
if (checkpoint and len(biases) > 2):
if checkpoint and len(biases) > 2:
raise ValueError(
"Checkpointed version permits only permits two bias terms"
)
@@ -375,7 +373,8 @@ def __init__(

def _prep_qkv(self,
q_x: torch.Tensor,
kv_x: torch.Tensor
kv_x: torch.Tensor,
apply_scale: bool = True
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor
]:
@@ -394,7 +393,8 @@ def _prep_qkv(self,
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)

q /= math.sqrt(self.c_hidden)
if apply_scale:
q /= math.sqrt(self.c_hidden)

return q, k, v

@@ -422,12 +422,10 @@ def forward(
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
use_flash: bool = False,
flash_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
@@ -437,14 +435,14 @@
[*, K, C_k] key data
biases:
List of biases that broadcast to [*, H, Q, K]
use_memory_efficient_kernel:
Whether to use a custom memory-efficient attention kernel.
This should be the default choice for most. If none of the
"use_<...>" flags are True, a stock PyTorch implementation
is used instead
use_deepspeed_evo_attention:
Whether to use DeepSpeed memory-efficient attention kernel.
If none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
use_lma:
Whether to use low-memory attention (Staats & Rabe 2021). If
none of the "use_<...>" flags are True, a stock PyTorch
Whether to use low-memory attention (Staats & Rabe 2021). It is
advantageous during inference with extremely long sequences.
If none of the "use_<...>" flags are True, a stock PyTorch
implementation is used instead
lma_q_chunk_size:
Query chunk size (for LMA)
@@ -459,13 +457,7 @@ def forward(
"lma_kv_chunk_size must be provided"
)

if use_flash and biases is not None:
raise ValueError(
"use_flash is incompatible with the bias option. For masking, "
"use flash_mask instead"
)

attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
attn_options = [use_deepspeed_evo_attention, use_lma]
if sum(attn_options) > 1:
raise ValueError(
"Choose at most one alternative attention algorithm"
@@ -474,34 +466,24 @@
if biases is None:
biases = []

# [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x)
# DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale=not use_deepspeed_evo_attention)

# [*, Q, H, C_hidden]
if is_fp16_enabled():
use_memory_efficient_kernel = False

if use_memory_efficient_kernel:
if use_deepspeed_evo_attention:
if len(biases) > 2:
raise ValueError(
"If use_memory_efficient_kernel is True, you may only "
"If use_deepspeed_evo_attention is True, you may only "
"provide up to two bias terms"
)
raise NotImplementedError(
"Memory-efficient kernel is not implemented. This had to be removed "
"because it was creating problems with attn_core_inplace_cuda."
)
# o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
# o = o.transpose(-2, -3)
o = _deepspeed_evo_attn(q, k, v, biases)
elif use_lma:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
for b in biases
]
o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
o = o.transpose(-2, -3)
elif use_flash:
o = _flash_attn(q, k, v, flash_mask)
else:
o = _attention(q, k, v, biases)
o = o.transpose(-2, -3)
@@ -600,6 +582,72 @@ def forward(self,
return m


@torch.jit.ignore
def _deepspeed_evo_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
):
"""""
Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
Args:
q:
[*, H, Q, C_hidden] query data
k:
[*, H, K, C_hidden] key data
v:
[*, H, V, C_hidden] value data
biases:
List of biases that broadcast to [*, H, Q, K]
"""

if not ds4s_is_installed:
raise ValueError(
"_deepspeed_evo_attn requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)

def reshape_dims(x):
no_batch_dims = len(x.shape[:-3])
if no_batch_dims < 2:
return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
if no_batch_dims > 2:
return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
return x

# [*, Q/K, H, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)

# Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
# for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
orig_shape = q.shape
if len(orig_shape[:-3]) != 2:
q = reshape_dims(q)
k = reshape_dims(k)
v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases]

# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
k.to(dtype=torch.bfloat16),
v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])

o = o.to(dtype=orig_dtype)
else:
o = DS4Sci_EvoformerAttention(q, k, v, biases)

o = o.reshape(orig_shape)
return o
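
A standalone illustration of the batch-dimension handling above: DS4Sci_EvoformerAttention expects inputs shaped [B, N, Q/K, H, C_hidden], i.e. exactly two leading batch dimensions, so reshape_dims pads or flattens whatever batch dimensions the caller provides. The sketch below restates the helper as a module-level function and uses made-up tensor sizes purely for illustration.

import torch

def reshape_dims(x: torch.Tensor) -> torch.Tensor:
    # Pad to two leading batch dims, or flatten the extras into the second one.
    no_batch_dims = len(x.shape[:-3])
    if no_batch_dims < 2:
        return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
    if no_batch_dims > 2:
        return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
    return x

q_single = torch.randn(128, 32, 8, 16)        # one batch dim: [N, Q, H, C]
q_nested = torch.randn(2, 4, 128, 32, 8, 16)  # three batch dims

print(reshape_dims(q_single).shape)  # torch.Size([1, 128, 32, 8, 16])
print(reshape_dims(q_nested).shape)  # torch.Size([2, 512, 32, 8, 16])
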


def _lma(
q: torch.Tensor,
k: torch.Tensor,
@@ -660,69 +708,3 @@ def _lma(
o[..., q_s: q_s + q_chunk_size, :] = q_chunk_out

return o


@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
if not fa_is_installed:
raise ValueError(
"_flash_attn requires that FlashAttention be installed"
)

batch_dims = q.shape[:-3]
no_heads, n, c = q.shape[-3:]
dtype = q.dtype

q = q.half()
k = k.half()
v = v.half()
kv_mask = kv_mask.half()

# [*, B, N, H, C]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)

# [B_flat, N, H, C]
q = q.reshape(-1, *q.shape[-3:])
k = k.reshape(-1, *k.shape[-3:])
v = v.reshape(-1, *v.shape[-3:])

# Flattened batch size
batch_size = q.shape[0]

# [B_flat * N, H, C]
q = q.reshape(-1, *q.shape[-2:])

q_max_s = n
q_cu_seqlens = torch.arange(
0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
)

# [B_flat, N, 2, H, C]
kv = torch.stack([k, v], dim=-3)
kv_shape = kv.shape

# [B_flat, N, 2 * H * C]
kv = kv.reshape(*kv.shape[:-3], -1)

kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])

out = flash_attn_unpadded_kvpacked_func(
q,
kv_unpad,
q_cu_seqlens,
kv_cu_seqlens,
q_max_s,
kv_max_s,
dropout_p=0.,
softmax_scale=1., # q has been scaled already
)

# [*, B, N, H, C]
out = out.reshape(*batch_dims, n, no_heads, c)

out = out.to(dtype=dtype)

return out
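
Taken together, the primitives.py changes collapse the attention dispatch to two optional fast paths: the DeepSpeed Evoformer kernel and low-memory attention. A rough usage sketch of the updated Attention interface follows; the import path matches the file shown above, but the constructor signature and all tensor sizes are assumptions made for illustration only.

import torch
from src.models.components.primitives import Attention  # path from the diff; constructor signature assumed

# Hypothetical dims: 64 input channels, 16 hidden channels per head, 4 heads.
attn = Attention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)

q_x = torch.randn(1, 128, 64)          # [*, Q, C_q]
kv_x = torch.randn(1, 128, 64)         # [*, K, C_k]
mask_bias = torch.zeros(1, 1, 1, 128)  # broadcasts to [*, H, Q, K]

# Default: stock PyTorch attention (no "use_<...>" flag set).
out = attn(q_x=q_x, kv_x=kv_x, biases=[mask_bias])

# DeepSpeed Evoformer kernel (needs deepspeed4science; at most two biases,
# and q is left unscaled because the kernel applies scaling internally):
# out = attn(q_x=q_x, kv_x=kv_x, biases=[mask_bias],
#            use_deepspeed_evo_attention=True)

# Low-memory attention for very long sequences:
# out = attn(q_x=q_x, kv_x=kv_x, biases=[mask_bias], use_lma=True)
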
29 changes: 22 additions & 7 deletions src/models/components/triangular_attention.py
@@ -62,11 +62,11 @@ def _chunk(self,
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"triangle! triangle!"
# triangle! triangle!
mha_inputs = {
"q_x": x,
"kv_x": x,
@@ -76,7 +76,7 @@
return chunk_layer(
partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma
),
mha_inputs,
@@ -89,14 +89,29 @@ def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
mask:
[*, I, J] mask tensor
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
indexing of all batch dimensions simultaneously (s.t. the
number of sub-batches is the product of the batch dimensions).
use_deepspeed_evo_attention:
whether to use DeepSpeed's EvoFormer attention
use_lma:
whether to use low-memory attention, mutually exclusive with
use_deepspeed_evo_attention
inplace_safe:
in-place attention during inference and training
Returns:
[*, I, J, C_in] output tensor
"""
@@ -106,7 +121,7 @@
x.shape[:-1],
)

if (not self.starting):
if not self.starting:
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)

@@ -129,7 +144,7 @@
x,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -138,7 +153,7 @@
q_x=x,
kv_x=x,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma
)

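For the triangular attention module, only the flag name changes at the call sites; the sketch below shows how a caller might drive it after this commit. The import path follows the file shown above, while the class name, constructor arguments, and tensor sizes are assumptions for illustration.

import torch
from src.models.components.triangular_attention import TriangleAttention  # class name and signature assumed

# Hypothetical sizes: 64x64 pair representation, c_in=32, c_hidden=16, 4 heads.
tri_att = TriangleAttention(c_in=32, c_hidden=16, no_heads=4, starting=True)

x = torch.randn(1, 64, 64, 32)   # [*, I, J, C_in]
mask = torch.ones(1, 64, 64)     # [*, I, J]

# Default path:
out = tri_att(x, mask=mask)

# Chunked inference with the DeepSpeed kernel, using the renamed flag:
# out = tri_att(x, mask=mask, chunk_size=4, use_deepspeed_evo_attention=True)
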