Skip to content

Commit

Permalink
Properly test non-contiguous input tensors. (#34)
Browse files Browse the repository at this point in the history
Co-authored-by: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
  • Loading branch information
wujingyue and IvanYashchuk authored Mar 22, 2024
1 parent bd17810 commit b8074a0
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions thunder/tests/test_cudnn_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)

# TODO: cudnnex seems to have a few mismatches. Will be enabled in a later PR.
# Non-contiguous input tensor case
nq = make(N, n_head, L, E).permute(0, 1, 3, 2)
nk = make(N, n_head, L, E).permute(0, 1, 3, 2)
nv = make(N, n_head, L, E).permute(0, 1, 3, 2)
nq = make(N, n_head, E, L).permute(0, 1, 3, 2)
nk = make(N, n_head, E, S).permute(0, 1, 3, 2)
nv = make(N, n_head, Ev, S).permute(0, 1, 3, 2)
yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False)

# Test the scale factor which was added in torch 2.1
Expand Down

0 comments on commit b8074a0

Please sign in to comment.