Max pool (#163)
jjsjann123 authored Apr 12, 2024
1 parent f39bea7 commit 709a062
Showing 1 changed file with 152 additions and 2 deletions.
154 changes: 152 additions & 2 deletions thunder/executors/torchex.py
@@ -37,6 +37,11 @@

from thunder.extend import OperatorExecutor, register_executor, add_always_executor

from thunder.core.transforms import (
    get_grad,
    put_grad,
)

ex = OperatorExecutor("torch", version=torch.__version__)
register_executor(ex)
add_always_executor(ex)
@@ -1217,6 +1222,100 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)


def _max_pool_with_indices_helper(
    ndim: int,
    a: TensorProxy,
    /,
    kernel_size: int | Sequence[int],
    stride: int | Sequence[int] | None = None,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
    ceil_mode: bool = False,
) -> tuple[TensorProxy, TensorProxy]:
    def div_rtn(x, y):
        q = x // y
        r = x % y
        if r != 0 and (r < 0) != (y < 0):
            q -= 1
        return q

    def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_: bool):
        out_size = (
            div_rtn(in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride_ - 1 if ceil_mode_ else 0), stride_) + 1
        )
        if ceil_mode_ and (out_size - 1) * stride_ >= in_ + pad_:
            out_size -= 1
        return out_size
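    # Worked example (editorial illustration, not part of the original commit): with
    # in_=7, kernel_=3, pad_=0, dilation_=1, stride_=2, ceil_mode_=False this gives
    # div_rtn(7 + 0 - 1 * (3 - 1) - 1, 2) + 1 = div_rtn(4, 2) + 1 = 3,
    # matching the three pooling windows [0:3], [2:5], [4:7] of a length-7 input.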

    def get_maybe_ith_entry(arg_name: str, seq: int | Sequence[int], i: int, default: int | None = None):
        if seq is None:
            return default

        if not isinstance(seq, Sequence):
            return seq

        if len(seq) == 1:
            return seq[0]
        else:
            utils.check(
                i < len(seq),
                lambda: f"invalid pooling argument: {arg_name} needs to be None / a scalar / size-{ndim} Sequence, but received {seq}",
            )
            return seq[i]

    out_sizes = []
    for i in range(ndim):
        in_ = a.shape[i - ndim]  # i - ndim is the i-th spatial dimension
        kernel_ = get_maybe_ith_entry("kernel_size", kernel_size, i)
        stride_ = get_maybe_ith_entry("stride", stride, i, kernel_)
        pad_ = get_maybe_ith_entry("padding", padding, i)
        dilation_ = get_maybe_ith_entry("dilation", dilation, i)
        utils.check(
            kernel_ is not None and stride_ is not None and pad_ is not None and dilation_ is not None,
            lambda: f"max_pool argument extraction failed.",
        )
        out_sizes.append(pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode))

    return TensorProxy(like=a, shape=(*a.shape[:-ndim], *out_sizes)), TensorProxy(like=a, shape=(*a.shape[:-ndim], *out_sizes))


def max_pool_with_indices_backward_meta(
    grad: TensorProxy,
    a: TensorProxy,
    kernel_size: int | Sequence[int],
    stride: int | Sequence[int] | None,
    padding: int | Sequence[int],
    dilation: int | Sequence[int],
    ceil_mode: bool,
    result1: TensorProxy,
) -> TensorProxy:
    return TensorProxy(like=a)
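# Note (editorial): the forward meta above returns a (values, indices) pair of proxies, while this
# backward meta returns a single proxy shaped like the input `a`; `result1` carries the indices
# produced by the forward op. This mirrors the aten signatures bound below.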


max_pool2d_with_indices_meta = partial(_max_pool_with_indices_helper, 2)

max_pool2d_with_indices = ex.register_operator(
    "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
)
max_pool2d_with_indices_backward = ex.register_operator(
    "max_pool2d_with_indices_backward",
    meta=max_pool_with_indices_backward_meta,
    fn=torch.ops.aten.max_pool2d_with_indices_backward,
)

max_pool3d_with_indices_meta = partial(_max_pool_with_indices_helper, 3)

max_pool3d_with_indices = ex.register_operator(
    "max_pool3d_with_indices", meta=max_pool3d_with_indices_meta, fn=torch.ops.aten.max_pool3d_with_indices
)
max_pool3d_with_indices_backward = ex.register_operator(
    "max_pool3d_with_indices_backward",
    meta=max_pool_with_indices_backward_meta,
    fn=torch.ops.aten.max_pool3d_with_indices_backward,
)
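# Illustration (editorial sketch, assuming the standard aten signatures; not part of this commit):
# the bound aten ops return and consume a (values, indices) pair, e.g.
#   x = torch.randn(1, 1, 4, 4)
#   vals, idx = torch.ops.aten.max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0], [1, 1], False)
#   grad_x = torch.ops.aten.max_pool2d_with_indices_backward(
#       torch.ones_like(vals), x, [2, 2], [2, 2], [0, 0], [1, 1], False, idx
#   )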

nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
pad = _register_torch_operation("pad", module=torch.nn.functional)
scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1461,8 +1560,59 @@ def _pad_prim_impl(
ltorch.log_softmax_backward, checker=_always_executable, execution_transform=_log_softmax_backward_transform
)
_register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
_register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable)
_register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)


def max_pool2d_bwd_wrapper(
    a: TensorProxy,
    /,
    kernel_size: int | Sequence[int],
    stride: int | Sequence[int] | None = None,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
    return_indices: bool = False,
    ceil_mode: bool = False,
) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
    primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)

    grad = get_grad(primals[0])
    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
    put_grad(a, grad_a)

    if return_indices:
        return primals
    else:
        return primals[0]


def max_pool3d_bwd_wrapper(
    a: TensorProxy,
    /,
    kernel_size: int | Sequence[int],
    stride: int | Sequence[int] | None = None,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
    return_indices: bool = False,
    ceil_mode: bool = False,
) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
    primals = max_pool3d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)

    grad = get_grad(primals[0])
    grad_a = max_pool3d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
    put_grad(a, grad_a)

    if return_indices:
        return primals
    else:
        return primals[0]


# The ltorch.max_pool2d/3d decomposition uses convolution, which has performance issues when run
# through torchex. We register a grad_transform that keeps both the forward and the backward
# max_pool as composite torch operators, which avoids the issue. For details see:
# https://github.com/Lightning-AI/lightning-thunder/issues/164. ATen does not expose explicit
# forward/backward functions for max_pool1d, so this specialization is only done for the 2d/3d cases.
ex.register_implementation(
    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
)
ex.register_implementation(
    ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool3d_bwd_wrapper
)
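# Usage sketch (editorial, assumes thunder.jit as the entry point; not part of this commit):
# with the torch executor active, a jitted max_pool2d call should now run its forward and
# backward through the composite aten ops registered above instead of a convolution
# decomposition, e.g.
#   import thunder, torch
#   def fn(x):
#       return torch.nn.functional.max_pool2d(x, kernel_size=2)
#   jfn = thunder.jit(fn)
#   x = torch.randn(1, 1, 4, 4, requires_grad=True)
#   jfn(x).sum().backward()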
_register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
nll_loss_backward = ex.register_operator(
"torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl
