Skip to content

Commit

Permalink
Clean the cudnn test. (PR2477)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored and Borda committed Mar 20, 2024
1 parent d1d7fb0 commit 296f04d
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions thunder/tests/test_cudnn_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,42 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-0.5, high=0.5)

n_head = 2
N = 8
N = 8 # batch size

# TODO: multiple of 8 seems to produce NaNs
L = random.randint(1, 10) * 64
L = random.randint(1, 10) * 64 # query's sequence length

alignment_factor = 8
S = random.randint(1, 10) * alignment_factor
E = random.randint(8, 16) * alignment_factor
Ev = random.randint(8, 16) * alignment_factor
S = random.randint(1, 10) * alignment_factor # key/value's sequence length
E = random.randint(8, 16) * alignment_factor # query/key's embedding size
Ev = random.randint(8, 16) * alignment_factor # value's embedding size

# 4-dim (multiheaded) causal cases
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
yield SampleInput(q, k, v, attn_mask := None, dropout_p := 0.0, is_causal := True)
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)

# TODO: cudnnex seems to have a few mismatches. Will be enabled in a later PR.
# Non-contiguous input tensor case
nq = make(N, n_head, L, E).permute(0, 1, 3, 2)
nk = make(N, n_head, L, E).permute(0, 1, 3, 2)
nv = make(N, n_head, L, E).permute(0, 1, 3, 2)
yield SampleInput(nq, nk, nv, attn_mask := None, dropout_p := 0.0, is_causal := False)
yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False)

# Test the scale factor which was added in torch 2.1
if LooseVersion(torch.__version__) >= LooseVersion("2.1.0"):
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
yield SampleInput(q, k, v, attn_mask := None, dropout_p := 0.0, is_causal := False, scale=0.123)
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=False, scale=0.123)

# TODO: cudnnex only supports grad_attn_mask with batch dim 1 and both sequence lengths divisible by 64. Release 9.0.1 will relax this constraint.
# Additive attn_mask
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
additive_attn_mask = make((1, n_head, L, S), dtype=q.dtype).tril()
yield SampleInput(q, k, v, attn_mask := additive_attn_mask, is_causal=False)
yield SampleInput(q, k, v, additive_attn_mask, is_causal=False)

# Boolean attn_mask
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
bool_attn_mask = make((1, n_head, L, S), dtype=torch.bool, low=1, high=1, requires_grad=False).tril()
yield SampleInput(q, k, v, attn_mask := bool_attn_mask, is_causal=False)
yield SampleInput(q, k, v, bool_attn_mask, is_causal=False)


grad_sdpa_cudnn_opinfo = OpInfo(
Expand Down

0 comments on commit 296f04d

Please sign in to comment.