Set cudnn executor tensors to query device (#2120)
Co-authored-by: Vedaanta Agarwalla <142048820+vedaanta-nvidia@users.noreply.github.com>
Co-authored-by: Vedaanta Agarwalla <vagarwalla@ipp2-1950.nvidia.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Mike Ruberry <38511765+mruberry@users.noreply.github.com>
5 people authored Feb 13, 2024
1 parent 05c44da commit 41625e3
Showing 2 changed files with 28 additions and 29 deletions.
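The core of the change, condensed into a hedged before/after sketch (graph and query are the variables used in the diff below): buffers that were always allocated on the default CUDA device now follow the device of the incoming query tensor.

# Before: auxiliary buffers were pinned to the default CUDA device
# (effectively cuda:0), breaking runs on any other GPU.
workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

# After: buffers are allocated per call on the device of the inputs.
workspace = torch.empty(graph.get_workspace_size(), device=query.device, dtype=torch.uint8)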
15 changes: 13 additions & 2 deletions thunder/executors/cudnn_layernormex.py
@@ -29,8 +29,19 @@ def cudnn_available() -> bool:
 import thunder.core.dtypes as dtypes
 from thunder.core.proxies import TensorProxy
 
-from thunder.executors.cudnnex import CudnnTensorAttributes
-from thunder.executors.cudnnex import make_cacheable_cudnn_graph_inputs, torch_to_cudnn_dtype
+from thunder.executors.cudnnex import CudnnTensorAttributes, torch_to_cudnn_dtype
+
+
+def make_cacheable_cudnn_graph_inputs(func):
+    def wrapper(*args, **kwargs):
+        cudnn_input_args = [
+            CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype) if isinstance(arg, torch.Tensor) else arg
+            for arg in args
+        ]
+        return func(*cudnn_input_args, **kwargs)
+
+    return wrapper
 
 
 from thunder.extend import OperatorExecutor, register_executor
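For context, one plausible way the relocated decorator is consumed (a minimal sketch, assuming CudnnTensorAttributes is hashable, e.g. a frozen dataclass; build_layer_norm_graph is an illustrative name, not from the diff): because torch.Tensor arguments are replaced by (size, stride, dtype) metadata, the wrapped builder can sit behind functools.lru_cache and be reused across calls with identical tensor layouts.

import functools

import torch


@make_cacheable_cudnn_graph_inputs
@functools.lru_cache(maxsize=1024)
def build_layer_norm_graph(a, weight, bias):
    # Here a, weight, and bias are CudnnTensorAttributes, never live
    # tensors, so calls with the same shapes/strides/dtypes share a graph.
    ...


x = torch.randn(16, 128)
w = torch.randn(128)
b = torch.randn(128)
build_layer_norm_graph(x, w, b)  # builds
build_layer_norm_graph(x, w, b)  # cache hit: same metadata, same graph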
42 changes: 15 additions & 27 deletions thunder/executors/cudnnex.py
@@ -2,6 +2,7 @@
 
 import torch
 import numpy as np
+import random
 
 from lightning_utilities.core.imports import package_available

@@ -53,17 +54,6 @@ class CudnnTensorAttributes:
     dtype: torch.dtype
 
 
-def make_cacheable_cudnn_graph_inputs(func):
-    def wrapper(*args, **kwargs):
-        cudnn_input_args = [
-            CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype) if isinstance(arg, torch.Tensor) else arg
-            for arg in args
-        ]
-        return func(*cudnn_input_args, **kwargs)
-
-    return wrapper
-
-
 from collections import OrderedDict


@@ -159,24 +149,16 @@ def _make_cudnn_sdpa_forward_graph(query, key, value, attn_mask, dropout_p, is_causal):
     graph.check_support()
     graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)
 
-    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
-
-    seed_device_tensor = torch.full((1, 1, 1, 1), 123456, dtype=torch.int32, device="cuda")
-    offset_device_tensor = torch.full((1, 1, 1, 1), 1, dtype=torch.int32, device="cuda")
-
     _cudnnex_cache[cache_key] = (
         Q,
         K,
         V,
         Attn_scale,
         Bias,
         Seed,
-        seed_device_tensor,
         Offset,
-        offset_device_tensor,
         O,
         softmax_stats,
-        workspace,
         graph,
     )
     return _cudnnex_cache[cache_key]
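_cudnnex_cache above is populated once per unique graph configuration. A hedged sketch of the surrounding memoization pattern (the OrderedDict import appears in this file; the lookup helper, key handling, and any eviction bound shown here are illustrative, not the executor's exact code):

from collections import OrderedDict

_cudnnex_cache = OrderedDict()
MAX_GRAPHS = 128  # illustrative bound, not from the diff


def _get_or_build(cache_key, build_fn):
    # Identical tensor metadata reuses the already-built cuDNN graph;
    # OrderedDict ordering would allow least-recently-used eviction.
    if cache_key in _cudnnex_cache:
        _cudnnex_cache.move_to_end(cache_key)
        return _cudnnex_cache[cache_key]
    _cudnnex_cache[cache_key] = build_fn()
    if len(_cudnnex_cache) > MAX_GRAPHS:
        _cudnnex_cache.popitem(last=False)  # evict least recently used
    return _cudnnex_cache[cache_key]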
@@ -280,19 +262,26 @@ def _cudnn_sdpa_fwd_impl(
         Attn_scale,
         Bias,
         Seed,
-        seed_tensor,
         Offset,
-        offset_tensor,
         O,
         softmax_stats,
-        workspace,
         graph,
     ) = _make_cudnn_sdpa_forward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal)
 
     b, h, s_q, d_q = query.size()
     _, _, _, d_v = value.size()
-    O_actual = torch.empty(b, h, s_q, d_v, dtype=value.dtype, device="cuda")
-    softmax_stats_actual = torch.empty(b, h, s_q, 1, dtype=torch.float32, device="cuda")
+    O_actual = torch.empty(b, h, s_q, d_v, dtype=value.dtype, device=query.device)
+    softmax_stats_actual = torch.empty(b, h, s_q, 1, dtype=torch.float32, device=query.device)
+    workspace = torch.empty(graph.get_workspace_size(), device=query.device, dtype=torch.uint8)
+
+    seed_tensor = (
+        torch.full((1, 1, 1, 1), random.randint(0, 123902390), dtype=torch.int32, device=query.device) if Seed else None
+    )
+    offset_tensor = (
+        torch.full((1, 1, 1, 1), random.randint(0, 123902390), dtype=torch.int32, device=query.device)
+        if Offset
+        else None
+    )
 
     # Default value of scale, if not provided, in all torch versions
     if scale is None:
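The hunk above moves all per-call allocations to execution time and onto query.device. Pulled out as a standalone sketch (Seed and Offset are the cuDNN graph tensor handles from the diff; treat this as illustrative, not the executor's exact code):

import random

import torch


def allocate_call_buffers(graph, query, value, Seed, Offset):
    b, h, s_q, _ = query.size()
    _, _, _, d_v = value.size()
    # Outputs and scratch space live on the same device as the inputs.
    out = torch.empty(b, h, s_q, d_v, dtype=value.dtype, device=query.device)
    stats = torch.empty(b, h, s_q, 1, dtype=torch.float32, device=query.device)
    workspace = torch.empty(graph.get_workspace_size(), device=query.device, dtype=torch.uint8)
    # Dropout RNG state is drawn fresh per call, and only if the graph
    # was built with dropout (the Seed/Offset handles are then non-None).
    seed = (
        torch.full((1, 1, 1, 1), random.randint(0, 123902390), dtype=torch.int32, device=query.device)
        if Seed
        else None
    )
    offset = (
        torch.full((1, 1, 1, 1), random.randint(0, 123902390), dtype=torch.int32, device=query.device)
        if Offset
        else None
    )
    return out, stats, workspace, seed, offset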
@@ -468,7 +457,6 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_causal):
     graph.check_support()
     graph.build_plans(cudnn.build_plan_policy.HEURISTICS_CHOICE)
 
-    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
     _cudnnex_cache[cache_key] = (
         Q,
         K,
@@ -484,7 +472,6 @@
         dK,
         dV,
         dBias,
-        workspace,
         graph,
     )
     return _cudnnex_cache[cache_key]
@@ -548,7 +535,6 @@ def cudnn_sdpa_bwd_impl(
         dK,
         dV,
         dBias,
-        workspace,
         graph,
     ) = _make_cudnn_sdpa_backward_graph(
         query_4d,
@@ -597,6 +583,8 @@ def cudnn_sdpa_bwd_impl(
         cudnn_to_torch_tensor[Bias] = attn_mask.detach()
         cudnn_to_torch_tensor[dBias] = grad_attn_mask
 
+    workspace = torch.empty(graph.get_workspace_size(), device=query.device, dtype=torch.uint8)
+
     graph.execute(cudnn_to_torch_tensor, workspace)
 
     if attn_mask is None:
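Net effect of the commit, as a hedged usage sketch: because every buffer (outputs, softmax stats, workspace, dropout seed/offset) now follows the input's device, SDPA through this executor should work for tensors placed on a non-default GPU. The shapes below and the plain-PyTorch call standing in for a Thunder-compiled function are illustrative:

import torch
import torch.nn.functional as F

# Hypothetical repro for the problem this commit fixes: inputs on a
# non-default GPU. Before the change, the executor's internal buffers
# were created on "cuda" (cuda:0) and could not serve cuda:1 tensors.
q = torch.randn(1, 8, 128, 64, device="cuda:1", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Under a Thunder-compiled function using the cudnn executor, this now
# runs entirely on cuda:1; plain PyTorch SDPA shown here as a stand-in.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)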
