
Commit

Merge branch 'main' into te_ddp
kshitij12345 authored Apr 3, 2024
2 parents 3e2d330 + 3bde9e6 commit bf6d275
Showing 11 changed files with 505 additions and 111 deletions.
15 changes: 15 additions & 0 deletions thunder/core/interpreter.py
@@ -1360,6 +1360,20 @@ def impl(x):
return _interpret_call(impl, x)


# https://docs.python.org/3/library/functions.html#enumerate
def _enumerate_lookaside(obj: Iterable, start: int = 0):
if not wrapped_isinstance(start, int):
return do_raise(TypeError(f"{type(start)} object cannot be interpreted as an integer"))

def impl(obj, start):
n = start
for elem in obj:
yield n, elem
n += 1

return _interpret_call(impl, obj, wrap_const(start))


@interpreter_needs_wrap
def eval_lookaside(
source: str | bytes | bytearray | CodeType, # A python expression
@@ -2651,6 +2665,7 @@ def create_namedtuple(typename: str, field_names: str, **kwargs):
# Python builtin lookasides
any: _any_lookaside,
bool: _bool_lookaside,
enumerate: _enumerate_lookaside,
exec: exec_lookaside,
eval: eval_lookaside,
getattr: _getattr_lookaside,
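Note: the generator in _enumerate_lookaside reproduces the semantics of the builtin enumerate so the interpreter can trace through it. A minimal standalone sketch of that equivalence (plain Python, not part of the commit):

def enumerate_impl(obj, start=0):
    # Same generator body as the lookaside's impl above.
    n = start
    for elem in obj:
        yield n, elem
        n += 1

assert list(enumerate_impl("abc", 1)) == list(enumerate("abc", 1))  # [(1, 'a'), (2, 'b'), (3, 'c')]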
54 changes: 54 additions & 0 deletions thunder/core/prims.py
@@ -246,6 +246,7 @@ class PrimIDs(Enum):
EMBEDDING_BACKWARD = auto()
LINEAR = auto()
PAD = auto()
BATCH_NORM = auto()
# Memory access methods
ITEM = auto()

@@ -3532,3 +3533,56 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_


embedding_backward = make_prim(PrimIDs.EMBEDDING_BACKWARD, "embedding_backward", meta=embedding_backward_meta)


def batch_norm_meta(
a: TensorProxy,
/,
weight: None | TensorProxy,
bias: None | TensorProxy,
running_mean: None | TensorProxy,
running_var: None | TensorProxy,
training: bool,
momentum: Number,
eps: Number,
) -> tuple[TensorProxy, None | TensorProxy, None | TensorProxy]:
# Checks types
utils.check_type(a, TensorProxy)
utils.check_type(momentum, Number)
utils.check_type(eps, Number)

utils.check(a.ndim >= 2, lambda: f"Input tensor must have at least batch and channel dimensions!")
if not training:
utils.check(
running_mean is not None and running_var is not None,
lambda: f"running_mean and running_var must be defined in evaluation mode",
)

num_features = a.shape[1]

def check_type_device_shape(param, param_name):
utils.check_type(param, TensorProxy)
utils.check_same_device(a, param)
utils.check(
param.shape == (num_features,),
lambda: f"Expected {param_name}.shape={param.shape} to be {(num_features,)}!",
)

if weight is not None:
check_type_device_shape(weight, "weight")
utils.check_same_dtype(a, weight)
if bias is not None:
check_type_device_shape(bias, "bias")
utils.check_same_dtype(a, bias)
if running_mean is not None:
check_type_device_shape(running_mean, "running_mean")
if running_var is not None:
check_type_device_shape(running_var, "running_var")
return (
TensorProxy(like=a),
(TensorProxy(like=a, shape=(num_features,)) if running_mean is None else TensorProxy(like=running_mean)),
(TensorProxy(like=a, shape=(num_features,)) if running_var is None else TensorProxy(like=running_var)),
)


batch_norm = make_prim(PrimIDs.BATCH_NORM, "batch_norm", meta=batch_norm_meta, tags=(OpTags.REDUCTION_OP,))
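For context, batch_norm_meta encodes the usual batch-norm shape rule: the primary output has the input's shape, and the saved statistics are per-channel vectors of length num_features = a.shape[1]. A standalone sketch of that rule (hypothetical helper, not part of the commit):

def batch_norm_output_shapes(input_shape):
    # batch_norm_meta requires at least batch and channel dimensions.
    assert len(input_shape) >= 2
    num_features = input_shape[1]
    return input_shape, (num_features,), (num_features,)

print(batch_norm_output_shapes((8, 3, 32, 32)))  # ((8, 3, 32, 32), (3,), (3,))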
37 changes: 37 additions & 0 deletions thunder/core/transforms.py
@@ -3058,6 +3058,43 @@ def embedding_backward(a, num_weights, padding_idx, scale_grad_by_freq, sparse,
return gweight


@register_augmented_forward(prims.PrimIDs.BATCH_NORM)
def batch_norm_aug_fwd(
a: TensorProxy,
weight: None | TensorProxy,
bias: None | TensorProxy,
running_mean: None | TensorProxy,
running_var: None | TensorProxy,
training: bool,
momentum: Number,
eps: Number,
) -> VJPDual:
primal = prims.batch_norm(
a,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
output_mask = [x is not None for x in (a, weight, bias)]
output, save_mean, save_invstd = primal
residuals = (a, weight, running_mean, running_var, save_mean, save_invstd, training, eps, output_mask)
return VJPDual(primal, residuals)


@register_backward(prims.PrimIDs.BATCH_NORM)
def batch_norm_backward(a, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, *grads):
from thunder.torch import batch_norm_backward

result = batch_norm_backward(
grads[0], a, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
)
return *result, None, None


@register_augmented_forward("torch.cumsum")
def cumsum_aug_fwd(a: Proxy, dim: int, *, dtype: None | dtypes.dtype = None) -> VJPDual:
from thunder.torch import cumsum
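The augmented forward stashes the input, parameters, saved statistics, and an output_mask as residuals, and the backward defers to thunder.torch.batch_norm_backward, which this commit maps onto torch.ops.aten.native_batch_norm_backward (see the torchex.py hunk below). A hedged sketch of that underlying contract, exercising the aten ops directly outside thunder:

import torch

x = torch.randn(4, 3, 8, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

# Forward returns the output plus the saved per-channel mean and inverse std,
# which mirror the residuals stashed by batch_norm_aug_fwd.
out, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, w, b, None, None, True, 0.1, 1e-5
)
grad_out = torch.randn_like(out)

# output_mask requests gradients only for the tensors that are present,
# matching `[x is not None for x in (a, weight, bias)]` above.
grad_x, grad_w, grad_b = torch.ops.aten.native_batch_norm_backward(
    grad_out, x, w, None, None, save_mean, save_invstd, True, 1e-5, [True, True, True]
)

# Cross-check against autograd.
ref = torch.autograd.grad(out, (x, w, b), grad_out)
for got, want in zip((grad_x, grad_w, grad_b), ref):
    torch.testing.assert_close(got, want)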
14 changes: 10 additions & 4 deletions thunder/executors/cudnn_layernormex.py
@@ -27,7 +27,11 @@
def make_cacheable_cudnn_graph_inputs(func):
def wrapper(*args, **kwargs):
cudnn_input_args = [
CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype) if isinstance(arg, torch.Tensor) else arg
(
CudnnTensorAttributes(arg.size(), arg.stride(), arg.dtype, arg.device.index)
if isinstance(arg, torch.Tensor)
else arg
)
for arg in args
]
return func(*cudnn_input_args, **kwargs)
@@ -84,9 +88,11 @@ def _transform_layer_norm_inputs(a, normalized_shape, weight, bias):

# Assume strides to be NCHW contiguous
assumed_stride = (elements_to_normalize, 1, 1, 1)
a_4d = CudnnTensorAttributes((batch_size, elements_to_normalize, 1, 1), assumed_stride, a.dtype)
weight_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, weight.dtype)
bias_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, bias.dtype)
a_4d = CudnnTensorAttributes((batch_size, elements_to_normalize, 1, 1), assumed_stride, a.dtype, a.device.index)
weight_4d = CudnnTensorAttributes(
(1, elements_to_normalize, 1, 1), assumed_stride, weight.dtype, weight.device.index
)
bias_4d = CudnnTensorAttributes((1, elements_to_normalize, 1, 1), assumed_stride, bias.dtype, bias.device.index)

return a_4d, weight_4d, bias_4d

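The device_index that now rides along in CudnnTensorAttributes matters because these attribute objects act as hashable cache keys for built cudnn graphs; without it, tensors with identical shape, stride, and dtype on different GPUs would share a cached graph. A self-contained sketch of that keying idea, using a local stand-in dataclass rather than the real one:

import torch
from dataclasses import dataclass

@dataclass(frozen=True)
class TensorKey:  # stand-in mirroring CudnnTensorAttributes' fields after this change
    size: tuple
    stride: tuple
    dtype: torch.dtype
    device_index: int | None  # None for CPU tensors in this sketch

def key_for(t: torch.Tensor) -> TensorKey:
    return TensorKey(tuple(t.size()), tuple(t.stride()), t.dtype, t.device.index)

a = torch.empty(2, 3)
b = torch.empty(2, 3)
assert key_for(a) == key_for(b)  # identical metadata -> identical cache key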
50 changes: 41 additions & 9 deletions thunder/executors/cudnnex.py
@@ -23,9 +23,30 @@ def cudnn_version() -> int:
import cudnn

cudnn_backend_version = cudnn.backend_version()
cudnn_handle = cudnn.create_handle()
# Mapping from device to cudnn handles
device_to_cudnn_handle = {}


# This function creates a new handle for the device that cudnn should
# run its kernels on. Since cudnn recommends creating as few handles
# as possible, this function caches these per-device handles.
def _get_cudnn_handle(query_device):
handle = device_to_cudnn_handle.get(query_device, None)
if handle is None:
with torch.cuda.device(query_device):
handle = cudnn.create_handle()
device_to_cudnn_handle[query_device] = handle

# Make sure the user stream is set on the handle
# Fetch the current user stream and pass the data pointer to set_stream API
cudnn.set_stream(handle=handle, stream=torch.cuda.current_stream(device=query_device).cuda_stream)

return handle


# WARNING: cudnn executor is experimental. Tests that use cudnn might fail.
# Issue for tracking support: https://github.com/Lightning-AI/lightning-thunder/issues/880

from dataclasses import dataclass
from functools import lru_cache
from typing import Union, Dict
@@ -54,6 +75,7 @@ class CudnnTensorAttributes:
size: tuple
stride: tuple
dtype: torch.dtype
device_index: int


from collections import OrderedDict
@@ -84,7 +106,9 @@ def __setitem__(self, key, value):

def _make_cudnn_sdpa_forward_graph(query, key, value, attn_mask, dropout_p, is_causal):
graph = cudnn.pygraph(
intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, handle=cudnn_handle
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
handle=_get_cudnn_handle(query.device_index),
)

Q = graph.tensor(name="Q", dim=query.size, stride=query.stride, data_type=torch_to_cudnn_dtype(query.dtype))
@@ -191,11 +215,11 @@ def compute_NHWC_strides(shape):
stride *= shape[i]
return tuple(strides)

query_4d = CudnnTensorAttributes(query.shape, compute_NHWC_strides(query.shape), query.dtype)
query_4d = CudnnTensorAttributes(query.shape, compute_NHWC_strides(query.shape), query.dtype, query.device.index)

key_4d = CudnnTensorAttributes(key.shape, compute_NHWC_strides(key.shape), key.dtype)
key_4d = CudnnTensorAttributes(key.shape, compute_NHWC_strides(key.shape), key.dtype, key.device.index)

value_4d = CudnnTensorAttributes(value.shape, compute_NHWC_strides(value.shape), value.dtype)
value_4d = CudnnTensorAttributes(value.shape, compute_NHWC_strides(value.shape), value.dtype, value.device.index)

attn_mask_4d = None
if attn_mask is not None:
Expand All @@ -204,7 +228,9 @@ def compute_NHWC_strides(shape):

# cudnn does not support boolean attn_mask, so make one with -inf
attn_mask_dtype = query.dtype if attn_mask.dtype in [torch.bool, dtypes.bool8] else attn_mask.dtype
attn_mask_4d = CudnnTensorAttributes(attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype)
attn_mask_4d = CudnnTensorAttributes(
attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype, attn_mask.device.index
)

return query_4d, key_4d, value_4d, attn_mask_4d

@@ -313,7 +339,10 @@ def _cudnn_sdpa_fwd_impl(
if attn_mask is not None:
cudnn_to_torch_tensor[Bias] = attn_mask.detach()

graph.execute(cudnn_to_torch_tensor, workspace)
# Even though the handle is created on query.device, cudnn still requires the current device to be set to query.device.
# This is most probably a bug and is being actively looked into.
with torch.cuda.device(query.device):
graph.execute(cudnn_to_torch_tensor, workspace, handle=_get_cudnn_handle(query.device))

return O_actual, softmax_stats_actual, seed_tensor, offset_tensor

@@ -368,7 +397,7 @@ def _make_cudnn_sdpa_backward_graph(
io_data_type=torch_to_cudnn_dtype(query.dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
handle=cudnn_handle,
handle=_get_cudnn_handle(query.device_index),
)

Q = graph.tensor(name="Q", dim=query.size, stride=query.stride, data_type=torch_to_cudnn_dtype(query.dtype))
@@ -623,7 +652,10 @@ def _cudnn_sdpa_bwd_impl(

workspace = torch.empty(graph.get_workspace_size(), device=query.device, dtype=torch.uint8)

graph.execute(cudnn_to_torch_tensor, workspace)
# Even though the handle is created on query.device, cudnn still requires the current device to be set to query.device.
# This is most probably a bug and is being actively looked into.
with torch.cuda.device(query.device):
graph.execute(cudnn_to_torch_tensor, workspace, handle=_get_cudnn_handle(query.device))

if cat_grad_qkv:
grads = (grad_qkv,)
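The per-device handle cache above keeps exactly one cudnn handle per GPU and rebinds the caller's current stream on every lookup. A GPU-free sketch of the same dict-backed pattern, with the cudnn calls stubbed out so it runs anywhere:

_device_to_handle: dict[int, object] = {}

def get_handle(device_index: int) -> object:
    handle = _device_to_handle.get(device_index)
    if handle is None:
        # Stand-in for cudnn.create_handle() executed under torch.cuda.device(device_index).
        handle = object()
        _device_to_handle[device_index] = handle
    # The real helper additionally calls cudnn.set_stream(handle=..., stream=...) here,
    # so the handle always follows the caller's current CUDA stream.
    return handle

assert get_handle(0) is get_handle(0)      # reused per device
assert get_handle(0) is not get_handle(1)  # distinct devices get distinct handles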
40 changes: 40 additions & 0 deletions thunder/executors/nvfuserex_impl.py
@@ -2000,6 +2000,46 @@ def var_mean(
register_supported(PrimIDs.VAR_MEAN, var_mean, _var_mean_check)


def _batch_norm_check(
a: TensorProxy,
weight: None | TensorProxy,
bias: None | TensorProxy,
running_mean: None | TensorProxy,
running_var: None | TensorProxy,
training: bool,
momentum: Number,
eps: Number,
) -> bool:
return are_supported_tensors(*(x for x in (a, weight, bias, running_mean, running_var) if x is not None))


def batch_norm(
a: TensorProxy,
weight: None | TensorProxy,
bias: None | TensorProxy,
running_mean: None | TensorProxy,
running_var: None | TensorProxy,
training: bool,
momentum: Number,
eps: Number,
*,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:
nva = getnv(a, fd, lc_to_nv_map)
nvweight = None if weight is None else getnv(weight, fd, lc_to_nv_map)
nvbias = None if bias is None else getnv(bias, fd, lc_to_nv_map)
nvrunning_mean = None if running_mean is None else getnv(running_mean, fd, lc_to_nv_map)
nvrunning_var = None if running_var is None else getnv(running_var, fd, lc_to_nv_map)
nvmomentum = getnv(momentum, fd, lc_to_nv_map)
nveps = getnv(eps, fd, lc_to_nv_map)

return fd.ops.batch_norm(nva, nvweight, nvbias, nvrunning_mean, nvrunning_var, nvmomentum, nveps, training)


register_supported(PrimIDs.BATCH_NORM, batch_norm, _batch_norm_check)


# Removes excessive float casts, like those that occur when autocasting
# NOTE This passes actually changes a program's semantics, because it will take a sequence like
# fp32 -> fp16 -> fp32 and remove all the operations, but casting fp32 values to fp16 can
8 changes: 8 additions & 0 deletions thunder/executors/torchex.py
@@ -743,6 +743,7 @@ def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)

Expand All @@ -754,6 +755,7 @@ def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = F
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu)

@@ -1177,9 +1179,15 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:

layer_norm = _register_torch_operation("layer_norm", module=torch.nn.functional)
batch_norm = _register_torch_operation("batch_norm", module=torch.nn.functional)
native_batch_norm = _register_torch_operation("torch.ops.aten.native_batch_norm", like=prims.batch_norm)
native_batch_norm_backward = _register_torch_operation(
"torch.ops.aten.native_batch_norm_backward", like=ltorch.batch_norm_backward
)

_register_implementation(ltorch.layer_norm, layer_norm, checker=_always_executable)
_register_implementation(ltorch.batch_norm, batch_norm, checker=_always_executable)
_register_implementation(prims.batch_norm, native_batch_norm, checker=_always_executable)
_register_implementation(ltorch.batch_norm_backward, native_batch_norm_backward, checker=_always_executable)

#
# NN operations
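With the batch_norm and hardswish registrations in place, a compiled function that uses both should be routed to these torch-executor implementations (or decomposed further by other executors). A hedged end-to-end check, assuming thunder.jit is the public entry point at this point in the repository's history:

import torch
import torch.nn.functional as F
import thunder

def fn(x, weight, bias):
    y = F.batch_norm(x, None, None, weight, bias, training=True)
    return F.hardswish(y)

jfn = thunder.jit(fn)

x = torch.randn(4, 3, 8, 8)
weight = torch.randn(3)
bias = torch.randn(3)

# The compiled result should match eager PyTorch.
torch.testing.assert_close(jfn(x, weight, bias), fn(x, weight, bias))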