Makes cudnn a default executor #427

Merged
merged 19 commits into from
May 28, 2024

Conversation

vedaanta
Collaborator

Before submitting
  • Was this discussed/approved via a GitHub issue? (not needed for typo fixes and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

cuDNN is now a default executor.
The only operation it targets is sdpa.

The main change is a stricter checker function: the checker now confirms that cuDNN supports both the forward and the backward graph before claiming an sdpa operation.
(The checker was previously made lenient in #57.)

Fixes #418.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

Collaborator

@tfogal tfogal left a comment

Thank you!

@parthmannan how do you feel about this from a perf perspective? One concern might be the expense of _make_cudnn_sdpa_*_graph, but maybe that's 1) cheaper than I worry about, or 2) not that relevant because _cudnn_sdpa_checker doesn't get called all that often anyway?

Review threads (resolved): thunder/executors/cudnnex.py, README.md
@vedaanta
Collaborator Author

vedaanta commented May 16, 2024

Before:

---------------------------------------------------------------------------------------------------- benchmark: 5 tests ---------------------------------------------------------------------------------------------------
Name (time in ms)                                               Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_llama2_7b_sdpa_grad[thunder+cudnn]                     16.1154 (1.0)      17.0821 (1.0)      16.4272 (1.0)      0.2377 (1.0)      16.4072 (1.0)      0.3396 (1.56)          9;1  60.8745 (1.0)          40           1
test_llama2_7b_sdpa_grad[thunder]                           27.0254 (1.68)     28.7664 (1.68)     27.5484 (1.68)     0.4476 (1.88)     27.3732 (1.67)     0.3508 (1.61)         12;4  36.2998 (0.60)         40           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After:

---------------------------------------------------------------------------------------------------- benchmark: 5 tests ---------------------------------------------------------------------------------------------------
Name (time in ms)                                               Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_llama2_7b_sdpa_grad[thunder+cudnn]                     16.1055 (1.0)      16.8787 (1.0)      16.4041 (1.0)      0.2354 (1.0)      16.4000 (1.0)      0.4397 (1.05)         18;0  60.9606 (1.0)          40           1
test_llama2_7b_sdpa_grad[thunder]                           16.1180 (1.00)     17.0175 (1.01)     16.4810 (1.00)     0.2475 (1.05)     16.4507 (1.00)     0.4206 (1.0)          17;0  60.6759 (1.00)         40           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
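
As a quick reading aid for the two tables, the sketch below recomputes the relative cost of the default `thunder` configuration against `thunder+cudnn` from the Mean column. The numbers are copied from the tables above; the dictionary layout and the helper name `relative_cost` are only illustrative.

```python
# Mean times in milliseconds, copied from the "Before" and "After" tables.
before = {"thunder+cudnn": 16.4272, "thunder": 27.5484}
after = {"thunder+cudnn": 16.4041, "thunder": 16.4810}


def relative_cost(means: dict) -> float:
    """Mean time of plain thunder divided by mean time of thunder+cudnn."""
    return means["thunder"] / means["thunder+cudnn"]


before_ratio = relative_cost(before)  # about 1.68, matching the table's (1.68)
after_ratio = relative_cost(after)    # close to 1.00, matching the table's (1.00)

# Fraction of the default configuration's mean time removed by this change.
default_time_saved = 1 - after["thunder"] / before["thunder"]

print(f"before: {before_ratio:.2f}x, after: {after_ratio:.2f}x, "
      f"saved: {default_time_saved:.0%}")
```

With cuDNN as a default executor, the plain `thunder` row lands within noise of the explicit `thunder+cudnn` row, which is the expected outcome of the change.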

@vedaanta vedaanta requested review from tfogal and Borda May 16, 2024 21:28
Collaborator

@tfogal tfogal left a comment

would love to hear from someone smarter than me on the requirements.txt question (Jirka? Ivan?), but +1 from me

Review threads (resolved): README.md, thunder/executors/cudnnex.py (outdated), thunder/__init__.py
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
@vedaanta vedaanta requested a review from wujingyue May 17, 2024 18:30
@parthmannan
Collaborator

> Thank you!
>
> @parthmannan how do you feel about this from a perf perspective? One concern might be the expense of _make_cudnn_sdpa_*_graph, but maybe that's 1) cheaper than I worry about, or 2) not that relevant because _cudnn_sdpa_checker doesn't get called all that often anyway?

@tfogal I don't think we need to worry about this from a performance perspective. As you pointed out in 2), I don't expect this to be called very often: hopefully only on the first iteration for static shapes, if I am thinking about this correctly.

What happens with dynamic shapes? We eventually plan to support them. Can cuDNN create broader graphs that work for many shapes, or will we need to call this every time?
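
The static versus dynamic shape concern can be pictured with a small sketch. It assumes a hypothetical cache keyed on the input shapes and dtype; this is not how thunder or cuDNN caches anything, and `checker_for_signature` and the `build_calls` counter are invented for illustration. With a fixed shape the expensive check runs once, while every previously unseen shape triggers it again.

```python
from functools import lru_cache

# Counts how often the (hypothetical) expensive graph build would run.
build_calls = 0


@lru_cache(maxsize=None)
def checker_for_signature(q_shape: tuple, k_shape: tuple, v_shape: tuple, dtype: str) -> bool:
    """Stand-in for a checker whose result is cached per shape/dtype signature."""
    global build_calls
    build_calls += 1  # represents building the forward and backward graphs
    return True       # assume support; real support rules are not modeled here


# Static shapes: three iterations, one signature, a single build.
shape = (2, 32, 128, 64)
for _ in range(3):
    checker_for_signature(shape, shape, shape, "bfloat16")
static_calls = build_calls

# Changing sequence length: every unseen shape misses the cache.
for seq_len in (256, 512, 1024):
    s = (2, 32, seq_len, 64)
    checker_for_signature(s, s, s, "bfloat16")
dynamic_calls = build_calls - static_calls
```

Whether a single cuDNN graph can cover a range of shapes is exactly the open question raised above; the sketch only shows why a per-shape cache stops helping once shapes vary.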

@vedaanta
Collaborator Author

@t-vi this is ready for your final review and merge. :)

@tfogal
Collaborator

tfogal commented May 21, 2024

@vedaanta I think the test failures are real / caused by this patch:

        jfn = thunder.jit(module)
>       result = jfn(*args, **kwargs)

thunder/tests/test_jit_general.py:608: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:49: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:626: in fn_
    result = cache_entry.computation_fn(*inps)
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:28: in decorate_autocast
    return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:28: in decorate_autocast
    return func(*args, **kwargs)
thunder.computation_729:68: in computation
    (y, _, _, _) = cudnn_sdpa_fwd(q, t74, t78, None, 0, True, scale=None)
thunder/executors/cudnnex.py:357: in _cudnn_sdpa_fwd_impl
    with torch.cuda.device(query.device):
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:361: in __init__
    self.idx = _get_device_index(device, optional=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

device = device(type='cpu'), optional = True, allow_cpu = False

    def _get_device_index(
        device: Any, optional: bool = False, allow_cpu: bool = False
    ) -> int:
        r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
    
        If :attr:`device` is a torch.device object, returns the device index if it
        is a CUDA device. Note that for a CUDA device without a specified index,

> this is ready for your final review and merge. :)

I think we should hold off pending investigation; please take a look and let's double check we're not causing regressions here.
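
The traceback above shows `cudnn_sdpa_fwd` being reached with a query on `device(type='cpu')`, which `torch.cuda.device` rejects. One plausible guard, sketched here under assumptions, is for the checker to decline any call whose inputs are not all on CUDA. Whether the eventual fix took this form is not shown in this thread. `FakeDevice`, `FakeTensor`, and `all_on_cuda` are hypothetical stand-ins so the example runs without a GPU or PyTorch; they mimic only the `.device.type` attribute.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    # Mimics the `.type` attribute of a device object ("cpu" or "cuda").
    type: str


@dataclass(frozen=True)
class FakeTensor:
    # Mimics the `.device` attribute of a tensor.
    device: FakeDevice


def all_on_cuda(*tensors) -> bool:
    """Return True only if every tensor reports a CUDA device."""
    return all(t.device.type == "cuda" for t in tensors)


cpu = FakeTensor(FakeDevice("cpu"))
gpu = FakeTensor(FakeDevice("cuda"))

claims_cpu_inputs = all_on_cuda(cpu, cpu, cpu)   # checker should decline
claims_mixed_inputs = all_on_cuda(gpu, cpu, gpu)  # checker should decline
claims_gpu_inputs = all_on_cuda(gpu, gpu, gpu)    # checker may proceed
```

A check like this matters more once an executor is enabled by default, because CPU-only test cases that never opted into it now pass through its checker.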

@t-vi
Collaborator

t-vi commented May 24, 2024

Yeah, as @tfogal points out, I think something is up in the tests.

@lantiga lantiga merged commit 7bd637a into main May 28, 2024
37 checks passed
@lantiga lantiga deleted the cudnn/default branch May 28, 2024 22:39
crcrpar pushed a commit that referenced this pull request May 29, 2024
Co-authored-by: Vedaanta Agarwalla <vagarwalla@ipp2-1949.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Thomas Viehmann <tv@beamnet.de>
Co-authored-by: Luca Antiga <luca@lightning.ai>
Development

Successfully merging this pull request may close these issues.

Enable cuDNN executor by default
8 participants