Max pool #163

Merged Apr 12, 2024 (23 commits)

Commits
1756c03
special grad_transform for max_poolxd
jjsjann123 Apr 11, 2024
9511603
forgot the name
jjsjann123 Apr 11, 2024
f641130
adding torch symbol for max_pool backward
jjsjann123 Apr 11, 2024
75e284c
fixing signature
jjsjann123 Apr 11, 2024
cc489c6
removing 3d because of pytorch aten API coverage
jjsjann123 Apr 11, 2024
486060a
fixing max_pool2d with indices
jjsjann123 Apr 11, 2024
d252018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
d614035
adding torch operator max_pool2d_with_indices
jjsjann123 Apr 11, 2024
104a96e
Merge remote-tracking branch 'jiej/max_pool' into max_pool
jjsjann123 Apr 11, 2024
c995511
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
82f56e9
patch backward operator
jjsjann123 Apr 12, 2024
22efa40
patch
jjsjann123 Apr 12, 2024
254fcb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
0e9c441
fixing logic
jjsjann123 Apr 12, 2024
d8647bb
functionally correct now at least
jjsjann123 Apr 12, 2024
e9e9dd9
Merge branch 'main' into max_pool
jjsjann123 Apr 12, 2024
7a06c07
refactor to support max_pool3d as well
jjsjann123 Apr 12, 2024
ef7ae3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
351469d
partial can't be used in grad_transform
jjsjann123 Apr 12, 2024
f8a167f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
dfaa058
addressing reviews
jjsjann123 Apr 12, 2024
155ac78
typo
jjsjann123 Apr 12, 2024
2a8da1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
154 changes: 152 additions & 2 deletions thunder/executors/torchex.py
@@ -37,6 +37,11 @@

from thunder.extend import OperatorExecutor, register_executor, add_always_executor

from thunder.core.transforms import (
get_grad,
put_grad,
)

ex = OperatorExecutor("torch", version=torch.__version__)
register_executor(ex)
add_always_executor(ex)
@@ -1217,6 +1222,100 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)


def _max_pool_with_indices_helper(
ndim: int,
a: TensorProxy,
/,
kernel_size: int | Sequence[int],
stride: int | Sequence[int] | None = None,
padding: int | Sequence[int] = 0,
dilation: int | Sequence[int] = 1,
ceil_mode: bool = False,
) -> tuple[TensorProxy, TensorProxy]:
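# Shape-inference helper shared by the max_pool{2,3}d_with_indices meta functions.
# For each spatial dimension the output size follows ATen's pooling_output_shape:
# floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1, rounded up instead when ceil_mode is set.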
def div_rtn(x, y):
q = x // y
r = x % y
if r != 0 and (r < 0) != (y < 0):
q -= 1
return q

def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_: bool):
out_size = (
div_rtn(in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride_ - 1 if ceil_mode_ else 0), stride_) + 1
)
if ceil_mode_ and (out_size - 1) * stride_ >= in_ + pad_:
out_size -= 1
return out_size

def get_maybe_ith_entry(arg_name: str, seq: int | Sequence[int] | None, i: int, default: int | None = None):
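# Normalize an argument that may be None, a scalar, or a per-dimension sequence into its value for spatial dimension i.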
if seq is None:
return default

if not isinstance(seq, Sequence):
return seq

if len(seq) == 1:
return seq[0]
else:
utils.check(
i < len(seq),
lambda: f"invalid pooling argument: {arg_name} needs to be None / a scalar / size-{ndim} Sequence, but received {seq}",
)
return seq[i]

out_sizes = []
for i in range(ndim):
in_ = a.shape[i - ndim] # i - ndim is the i-th spatial dimension
kernel_ = get_maybe_ith_entry("kernel_size", kernel_size, i)
stride_ = get_maybe_ith_entry("stride", stride, i, kernel_)
pad_ = get_maybe_ith_entry("padding", padding, i)
dilation_ = get_maybe_ith_entry("dilation", dilation, i)
utils.check(
kernel_ is not None and stride_ is not None and pad_ is not None and dilation_ is not None,
lambda: f"max_pool argument extraction failed.",
)
out_sizes.append(pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode))

return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)


def max_pool_with_indices_backward_meta(
grad: TensorProxy,
a: TensorProxy,
kernel_size: int | Sequence[int],
stride: int | Sequence[int] | None,
padding: int | Sequence[int],
dilation: int | Sequence[int],
ceil_mode: bool,
result1: TensorProxy,
) -> TensorProxy:
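# The gradient with respect to the input has the input's shape and dtype.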
return TensorProxy(like=a)


max_pool2d_with_indices_meta = partial(_max_pool_with_indices_helper, 2)

max_pool2d_with_indices = ex.register_operator(
"max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
)
max_pool2d_with_indices_backward = ex.register_operator(
"max_pool2d_with_indices_backward",
meta=max_pool_with_indices_backward_meta,
fn=torch.ops.aten.max_pool2d_with_indices_backward,
)

max_pool3d_with_indices_meta = partial(_max_pool_with_indices_helper, 3)

max_pool3d_with_indices = ex.register_operator(
"max_pool3d_with_indices", meta=max_pool3d_with_indices_meta, fn=torch.ops.aten.max_pool3d_with_indices
)
max_pool3d_with_indices_backward = ex.register_operator(
"max_pool3d_with_indices_backward",
meta=max_pool_with_indices_backward_meta,
fn=torch.ops.aten.max_pool3d_with_indices_backward,
)

nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
pad = _register_torch_operation("pad", module=torch.nn.functional)
scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1461,8 +1560,59 @@ def _pad_prim_impl(
ltorch.log_softmax_backward, checker=_always_executable, execution_transform=_log_softmax_backward_transform
)
_register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
_register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable)
_register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)


def max_pool2d_bwd_wrapper(
a: TensorProxy,
/,
kernel_size: int | Sequence[int],
stride: int | Sequence[int] | None = None,
padding: int | Sequence[int] = 0,
dilation: int | Sequence[int] = 1,
return_indices: bool = False,
ceil_mode: bool = False,
) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
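# Forward runs the fused ATen kernel (values + indices); the saved indices are reused by
# max_pool2d_with_indices_backward, wired up through get_grad/put_grad.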
primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)

grad = get_grad(primals[0])
grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
put_grad(a, grad_a)

if return_indices:
return primals
else:
return primals[0]


def max_pool3d_bwd_wrapper(
a: TensorProxy,
/,
kernel_size: int | Sequence[int],
stride: int | Sequence[int] | None = None,
padding: int | Sequence[int] = 0,
dilation: int | Sequence[int] = 1,
return_indices: bool = False,
ceil_mode: bool = False,
) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
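# Same flow as the 2d wrapper above, dispatching to the ATen 3d kernels.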
primals = max_pool3d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)

grad = get_grad(primals[0])
grad_a = max_pool3d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
put_grad(a, grad_a)

if return_indices:
return primals
else:
return primals[0]


# The ltorch.max_pool2d/3d decomposition uses convolution, which has a performance issue when run through torchex.
# Registering a grad_transform keeps both the forward and backward max_pool as composite torch operators, avoiding
# that issue. For details see: https://github.com/Lightning-AI/lightning-thunder/issues/164. ATen does not expose
# explicit max_pool1d forward/backward functions, so the specialization is only done for the 2d and 3d cases.
ex.register_implementation(
ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
)
ex.register_implementation(
ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool3d_bwd_wrapper
)
_register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
nll_loss_backward = ex.register_operator(
"torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl
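
A minimal usage sketch of the new path (assuming thunder.jit as the entry point; the function and tensor names below are illustrative, not taken from this PR):

import torch
import thunder

def fn(x):
    return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)

jfn = thunder.jit(fn)
x = torch.randn(1, 3, 8, 8, requires_grad=True)
out = jfn(x)
# With the torch executor, the backward of max_pool2d is expected to run through
# max_pool2d_with_indices_backward via the grad_transform registered above.
out.sum().backward()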