2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 108 files
63 changes: 52 additions & 11 deletions tests/pytorch/attention/test_attention.py
@@ -166,7 +166,7 @@ def test_dot_product_attention(

# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
unfused_attn_fwd, unfused_max_score, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"UnfusedDotProductAttention",
@@ -180,7 +180,7 @@ def test_dot_product_attention(
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -192,7 +192,7 @@ def test_dot_product_attention(
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
fused_attn_fwd, fused_max_score, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -203,7 +203,7 @@ def test_dot_product_attention(
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
fused_attn_fwd_1, fused_max_score_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
@@ -216,7 +216,7 @@ def test_dot_product_attention(

# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
flash_attn_fwd, flash_max_score, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
@@ -232,16 +232,22 @@ def test_dot_product_attention(
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(flash_max_score, unfused_max_score, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(fused_max_score, unfused_max_score, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
if config.return_max_score:
torch.testing.assert_close(fused_max_score, flash_max_score, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backends) == 2:
@@ -260,6 +266,34 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_max_score = {
# test: ModelConfig(b, sq, hq, dqk)
"max_score_1_0": ModelConfig(8, 128, 16, 64),
"max_score_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"max_score_2_0": ModelConfig(2, 2048, 24, 128),
"max_score_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
"max_score_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048),
"max_score_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048),
"max_score_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048),
"max_score_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048),
"max_score_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048),
"max_score_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
"max_score_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"max_score_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
}


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_max_score])
@pytest.mark.parametrize("model", model_configs_max_score.keys())
def test_dpa_max_score(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
config = model_configs[model]
config.return_max_score = True
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)

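Usage note (not part of the diff): a minimal sketch of what the new flag does from a caller's point of view, mirroring how the test above constructs and calls the module. It assumes a TransformerEngine build that includes this PR; argument names other than return_max_score follow common DotProductAttention usage, and the tensor shapes assume the default "sbhd" layout.

import torch
from transformer_engine.pytorch import DotProductAttention

# Shapes follow the "max_score_2_0" config above: b=2, sq=2048, hq=24, dqk=128.
b, sq, hq, dqk = 2, 2048, 24, 128
q = torch.randn(sq, b, hq, dqk, dtype=torch.bfloat16, device="cuda")
k = torch.randn(sq, b, hq, dqk, dtype=torch.bfloat16, device="cuda")
v = torch.randn(sq, b, hq, dqk, dtype=torch.bfloat16, device="cuda")

attn = DotProductAttention(
    num_attention_heads=hq,
    kv_channels=dqk,
    return_max_score=True,  # the new flag exercised by test_dpa_max_score
).to(dtype=torch.bfloat16, device="cuda")

# With return_max_score=True the forward pass returns a tuple: the attention
# output plus a max_score statistic, which the test unpacks as
# `out, max_score = block(...)`.
out, max_score = attn(q, k, v)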

model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -1065,6 +1099,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layer_number=1,
attention_type=config.attn_type,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
@@ -1082,7 +1117,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
k = inp[1]
v = inp[2]
d_out = out_grad
out = block(
out, max_score = block(
q,
k,
v,
@@ -1103,15 +1138,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
fast_zero_fill=True,
)
if is_training:
out.backward(d_out)
out.backward((d_out, torch.zeros(1, device="cuda")))

if config.return_max_score:
out = (out, max_score)
else:
out = (out, None)

Review comment: This changes the returned values even when max score is not requested. Should we keep backward compatibility when the new feature is not used? (See the backward-compatible sketch after this file's diff.)


d_softmax_offset = None
if is_training and config.softmax_type != "vanilla":
d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return *out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return *out, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1145,9 +1186,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
return out_orig, (None, None, None, d_softmax_offset)
else:
if is_training:
return out, (q.grad, k.grad, v.grad, d_softmax_offset)
return *out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
return out, (None, None, None, d_softmax_offset)
return *out, (None, None, None, d_softmax_offset)


model_configs_te_layer = {
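Regarding the review comment above: a minimal sketch of the backward-compatible return shape the reviewer appears to be asking for. _pack_output is a hypothetical helper used only for illustration; it is not code from this PR or from TransformerEngine.

from typing import Optional, Tuple, Union

import torch


def _pack_output(
    out: torch.Tensor,
    max_score: Optional[torch.Tensor],
    return_max_score: bool,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Hypothetical helper: only widen the return type when the caller opted in."""
    if return_max_score:
        # New behavior: callers that requested the statistic receive a 2-tuple.
        return out, max_score
    # Old behavior preserved: existing call sites keep receiving a bare tensor.
    return out

With that shape, existing call sites of the form out = block(q, k, v, ...) would keep working unchanged, and only callers that set return_max_score=True would need the tuple unpacking used in the updated test.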
3 changes: 3 additions & 0 deletions tests/pytorch/utils.py
@@ -205,6 +205,7 @@ def __init__(
window_size: Tuple[int, int] = (-1, -1),
context_parallel: bool = False,
cp_comm_type: str = "p2p",
return_max_score=False,
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
@@ -233,6 +234,7 @@ def __init__(
self.window_size = check_set_window_size(self.attn_mask_type, window_size)
self.context_parallel = context_parallel
self.cp_comm_type = cp_comm_type
self.return_max_score = return_max_score
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
@@ -318,6 +320,7 @@ def test():
is_training=is_training,
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_score=config.return_max_score,
)
(
use_flash_attention,