From 1756c03b27d5d0f5cd3ae8d8e220d22dcd47fa9d Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 10 Apr 2024 17:30:23 -0700
Subject: [PATCH 01/21] special grad_transform for max_poolxd

---
 thunder/executors/torchex.py | 53 ++++++++++++++++++++++++++++++++++--
 1 file changed, 51 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 6345e2780f..ef80dcbfe9 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1217,6 +1217,10 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
 max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
+max_pool2d_backward = _register_torch_operation("max_pool2d_backward", module=torch.ops.aten.max_pool2d_backward.op)
+max_pool2d_with_indices_backward = _register_torch_operation("max_pool2d_with_indices_backward", module=torch.ops.aten.max_pool2d_backward)
+max_pool3d_backward = _register_torch_operation("max_pool3d_backward", module=torch.ops.aten.max_pool3d_backward.op)
+max_pool3d_with_indices_backward = _register_torch_operation("max_pool3d_with_indices_backward", module=torch.ops.aten.max_pool3d_backward)
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1461,8 +1465,53 @@ def _pad_prim_impl(
     ltorch.log_softmax_backward, checker=_always_executable, execution_transform=_log_softmax_backward_transform
 )
 _register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
-_register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable)
-_register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)
+
+def max_pool2d_bwd_wrapper(
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+):
+    primals = max_pool2d(a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+
+    if return_indices:
+        grad = get_grad(primals[0])
+        grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    else:
+        grad = get_grad(primals)
+        grad_a = max_pool2d_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode)
+    put_grad(a, grad_a)
+
+    return primals
+
+def max_pool3d_bwd_wrapper(
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+):
+    primals = max_pool3d(a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+
+    if return_indices:
+        grad = get_grad(primals[0])
+        grad_a = max_pool3d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    else:
+        grad = get_grad(primals)
+        grad_a = max_pool3d_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode)
+    put_grad(a, grad_a)
+
+    return primals
+
+ex._register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=)
+ex._register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
     "torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl

From 9511603e4a23a48f30259b6fc90cb7f3c4afe7a0 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 10 Apr 2024 17:31:21 -0700
Subject: [PATCH 02/21] forgot the name

---
 thunder/executors/torchex.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index ef80dcbfe9..5ae9be14e7 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1510,8 +1510,8 @@ def max_pool3d_bwd_wrapper(
 
     return primals
 
-ex._register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=)
-ex._register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=)
+ex._register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper)
+ex._register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool3d_bwd_wrapper)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
     "torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl

From f641130a0d80d31a464301423a43a789ebd69610 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 10 Apr 2024 17:54:20 -0700
Subject: [PATCH 03/21] adding torch symbol for max_pool backward

---
 thunder/torch/__init__.py | 58 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 0cee3f169b..c5106fe199 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3081,6 +3081,64 @@ def max_pool2d(
     return _max_pool_helper(2, a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
 
 
+@torchsymbol(torch.ops.aten.max_pool2d_backward, id="torch.ops.aten.max_pool2d_backward", is_method=False)
+def max_pool2d_backward(
+    grad: TensorProxy,
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+) -> TensorProxy:
+    return TensorProxy(like=a)
+
+
+@torchsymbol(torch.ops.aten.max_pool2d_with_indices_backward, id="torch.ops.aten.max_pool2d_with_indices_backward", is_method=False)
+def max_pool2d_with_indices_backward(
+    grad: TensorProxy,
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+    result1: TensorProxy,
+) -> list[TensorProxy | None]:
+    return [TensorProxy(like=a), None]
+
+
+@torchsymbol(torch.ops.aten.max_pool3d_backward, id="torch.ops.aten.max_pool3d_backward", is_method=False)
+def max_pool3d_backward(
+    grad: TensorProxy,
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+) -> TensorProxy:
+    return TensorProxy(like=a)
+
+
+@torchsymbol(torch.ops.aten.max_pool3d_with_indices_backward, id="torch.ops.aten.max_pool3d_with_indices_backward", is_method=False)
+def max_pool3d_with_indices_backward(
+    grad: TensorProxy,
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+    result1: TensorProxy,
+) -> list[TensorProxy | None]:
+    return [TensorProxy(like=a), None]
+
+
 @torchsymbol(torch.max_pool3d, torch.nn.functional.max_pool3d, id="torch.nn.functional.max_pool3d", is_method=False)
 def max_pool3d(
     a: TensorProxy,

From 75e284c65c61f9464a8077b7d7c7ebd3c8bc9037 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 10 Apr 2024 17:56:54 -0700
Subject: [PATCH 04/21] fixing signature

---
 thunder/torch/__init__.py | 44 ++++++++++++++++++---------------------
 1 file changed, 20 insertions(+), 24 deletions(-)

diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index c5106fe199..c4bbcdbf69 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3085,12 +3085,11 @@ def max_pool2d(
 def max_pool2d_backward(
     grad: TensorProxy,
     a: TensorProxy,
-    /,
-    kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    ceil_mode: bool = False,
+    kernel_size: int,
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
 ) -> TensorProxy:
     return TensorProxy(like=a)
 
@@ -3099,12 +3098,11 @@ def max_pool2d_backward(
 def max_pool2d_with_indices_backward(
     grad: TensorProxy,
     a: TensorProxy,
-    /,
-    kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    ceil_mode: bool = False,
+    kernel_size: int,
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
     result1: TensorProxy,
 ) -> list[TensorProxy | None]:
     return [TensorProxy(like=a), None]
@@ -3114,12 +3112,11 @@ def max_pool2d_with_indices_backward(
 def max_pool3d_backward(
     grad: TensorProxy,
     a: TensorProxy,
-    /,
-    kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    ceil_mode: bool = False,
+    kernel_size: int,
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
 ) -> TensorProxy:
     return TensorProxy(like=a)
 
@@ -3128,12 +3125,11 @@ def max_pool3d_backward(
 def max_pool3d_with_indices_backward(
     grad: TensorProxy,
     a: TensorProxy,
-    /,
-    kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    ceil_mode: bool = False,
+    kernel_size: int,
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
     result1: TensorProxy,
 ) -> list[TensorProxy | None]:
     return [TensorProxy(like=a), None]

From cc489c6163bdad2f85769dce17ea0e9aa6136ac9 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 10 Apr 2024 18:00:10 -0700
Subject: [PATCH 05/21] removing 3d because of pytorch aten API coverage

---
 thunder/executors/torchex.py | 26 +------------------------
 thunder/torch/__init__.py    | 27 ---------------------------
 2 files changed, 1 insertion(+), 52 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 5ae9be14e7..adeba4081b 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1219,8 +1219,6 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
 max_pool2d_backward = _register_torch_operation("max_pool2d_backward", module=torch.ops.aten.max_pool2d_backward.op)
 max_pool2d_with_indices_backward = _register_torch_operation("max_pool2d_with_indices_backward", module=torch.ops.aten.max_pool2d_backward)
-max_pool3d_backward = _register_torch_operation("max_pool3d_backward", module=torch.ops.aten.max_pool3d_backward.op)
-max_pool3d_with_indices_backward = _register_torch_operation("max_pool3d_with_indices_backward", module=torch.ops.aten.max_pool3d_backward)
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1488,30 +1486,8 @@ def max_pool2d_bwd_wrapper(
 
     return primals
 
-def max_pool3d_bwd_wrapper(
-    a: TensorProxy,
-    /,
-    kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    return_indices: bool = False,
-    ceil_mode: bool = False,
-):
-    primals = max_pool3d(a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
-
-    if return_indices:
-        grad = get_grad(primals[0])
-        grad_a = max_pool3d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
-    else:
-        grad = get_grad(primals)
-        grad_a = max_pool3d_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode)
-    put_grad(a, grad_a)
-
-    return primals
-
 ex._register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper)
-ex._register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool3d_bwd_wrapper)
+_register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
     "torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index c4bbcdbf69..ae1b2b3d58 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3108,33 +3108,6 @@ def max_pool2d_with_indices_backward(
     return [TensorProxy(like=a), None]
 
 
-@torchsymbol(torch.ops.aten.max_pool3d_backward, id="torch.ops.aten.max_pool3d_backward", is_method=False)
-def max_pool3d_backward(
-    grad: TensorProxy,
-    a: TensorProxy,
-    kernel_size: int,
-    stride: int | Sequence[int] | None,
-    padding: int | Sequence[int],
-    dilation: int | Sequence[int],
-    ceil_mode: bool,
-) -> TensorProxy:
-    return TensorProxy(like=a)
-
-
-@torchsymbol(torch.ops.aten.max_pool3d_with_indices_backward, id="torch.ops.aten.max_pool3d_with_indices_backward", is_method=False)
-def max_pool3d_with_indices_backward(
-    grad: TensorProxy,
-    a: TensorProxy,
-    kernel_size: int,
-    stride: int | Sequence[int] | None,
-    padding: int | Sequence[int],
-    dilation: int | Sequence[int],
-    ceil_mode: bool,
-    result1: TensorProxy,
-) -> list[TensorProxy | None]:
-    return [TensorProxy(like=a), None]
-
-
 @torchsymbol(torch.max_pool3d, torch.nn.functional.max_pool3d, id="torch.nn.functional.max_pool3d", is_method=False)
 def max_pool3d(
     a: TensorProxy,

From 486060a7bd0d72ab4a84549ee7a283eb79b3c0c7 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Wed, 10 Apr 2024 18:33:48 -0700
Subject: [PATCH 06/21] fixing max_pool2d with indices

---
 thunder/executors/torchex.py | 25 ++++++++++++++-----------
 thunder/torch/__init__.py    | 15 +--------------
 2 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index adeba4081b..92503a1ba7 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -37,6 +37,11 @@
 
 from thunder.extend import OperatorExecutor, register_executor, add_always_executor
 
+from thunder.core.transforms import (
+    get_grad,
+    put_grad,
+)
+
 ex = OperatorExecutor("torch", version=torch.__version__)
 register_executor(ex)
 add_always_executor(ex)
@@ -1217,8 +1222,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
 max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
-max_pool2d_backward = _register_torch_operation("max_pool2d_backward", module=torch.ops.aten.max_pool2d_backward.op)
-max_pool2d_with_indices_backward = _register_torch_operation("max_pool2d_with_indices_backward", module=torch.ops.aten.max_pool2d_backward)
+max_pool2d_with_indices_backward = _register_torch_operation("torch.ops.aten.max_pool2d_with_indices_backward", like=ltorch.max_pool2d_with_indices_backward)
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1474,19 +1478,18 @@ def max_pool2d_bwd_wrapper(
     return_indices: bool = False,
     ceil_mode: bool = False,
 ):
-    primals = max_pool2d(a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
+    primals = max_pool2d(a, kernel_size, stride, padding, dilation, True, ceil_mode)
 
-    if return_indices:
-        grad = get_grad(primals[0])
-        grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
-    else:
-        grad = get_grad(primals)
-        grad_a = max_pool2d_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode)
+    grad = get_grad(primals[0])
+    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
     put_grad(a, grad_a)
 
-    return primals
+    if return_indices:
+        return primals
+    else:
+        return primals[0]
 
-ex._register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper)
+ex.register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper)
 _register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index ae1b2b3d58..fe0dee4012 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3081,20 +3081,7 @@ def max_pool2d(
     return _max_pool_helper(2, a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
 
 
-@torchsymbol(torch.ops.aten.max_pool2d_backward, id="torch.ops.aten.max_pool2d_backward", is_method=False)
-def max_pool2d_backward(
-    grad: TensorProxy,
-    a: TensorProxy,
-    kernel_size: int,
-    stride: int | Sequence[int] | None,
-    padding: int | Sequence[int],
-    dilation: int | Sequence[int],
-    ceil_mode: bool,
-) -> TensorProxy:
-    return TensorProxy(like=a)
-
-
-@torchsymbol(torch.ops.aten.max_pool2d_with_indices_backward, id="torch.ops.aten.max_pool2d_with_indices_backward", is_method=False)
+@torchsymbol(torch.ops.aten.max_pool2d_with_indices_backward, id="max_pool2d_with_indices_backward", is_method=False)
 def max_pool2d_with_indices_backward(
     grad: TensorProxy,
     a: TensorProxy,

From d252018786cfbbe4bde3d6c03469bfc207568341 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 11 Apr 2024 17:52:34 +0000
Subject: [PATCH 07/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/executors/torchex.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 92503a1ba7..9c5ce68d81 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1222,7 +1222,9 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
 max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
-max_pool2d_with_indices_backward = _register_torch_operation("torch.ops.aten.max_pool2d_with_indices_backward", like=ltorch.max_pool2d_with_indices_backward)
+max_pool2d_with_indices_backward = _register_torch_operation(
+    "torch.ops.aten.max_pool2d_with_indices_backward", like=ltorch.max_pool2d_with_indices_backward
+)
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1468,6 +1470,7 @@ def _pad_prim_impl(
 )
 _register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
 
+
 def max_pool2d_bwd_wrapper(
     a: TensorProxy,
     /,
@@ -1489,7 +1492,10 @@ def max_pool2d_bwd_wrapper(
     else:
         return primals[0]
 
-ex.register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper)
+
+ex.register_implementation(
+    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
+)
 _register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(

From d6140359e65dc700e1f0675509d8184612879e68 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 14:04:47 -0700
Subject: [PATCH 08/21] adding torch operator max_pool2d_with_indices

---
 thunder/executors/torchex.py | 71 +++++++++++++++++++++++++++++++++++-
 1 file changed, 70 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 92503a1ba7..dbc37bcd0f 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1222,7 +1222,76 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
 max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
+
+def _max_pool_with_indices_helper(
+    ndim: int,
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
+) -> [TensorProxy, TensorProxy]:
+    def div_rtn(x, y):
+        q = x / y;
+        r = x % y;
+        if r != 0 and (r < 0) != (y < 0):
+            q -= 1
+        return q
+
+    def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_:bool):
+        out_size = div_rtn(in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride - 1 if ceil_mode else 0), stride) + 1
+        if ceil_mode and (out_size - 1) * stride >= in_ + pad_:
+            out_size -= 1
+        return out_size
+
+    def get_maybe_ith_entry(seq : Sequence[int], i : int, default : int = None):
+        if seq is None:
+            return default
+
+        if len(seq) == 1:
+            return seq[0]
+        else:
+            return seq[i]
+
+    out_sizes = []
+    for i in range(ndim):
+        in_ = a.shape[i - ndim] # i - ndim is the i-th spatial dimension
+        kernel_ = get_maybe_ith_entry(kernel_size, i)
+        stride_ = get_maybe_ith_entry(stride, i, kernel_)
+        pad_ = get_maybe_ith_entry(padding, i)
+        dilation_ = get_maybe_ith_entry(dilation, i)
+        out_sizes.append(pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode))
+
+    return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
+
+def max_pool2d_with_indices_meta(
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+) -> [TensorProxy, TensorProxy]:
+    return _max_pool_with_indices_helper(2, a, kernel_size, stride, padding, dilation, ceil_mode)
+
+def _max_pool2d_with_indices(
+    a: TensorLike,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+) -> [TensorLike, TensorLike]:
+    return torch.ops.aten.max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+
+
+max_pool2d_with_indices = ex.register_operator("max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=_max_pool2d_with_indices)
 max_pool2d_with_indices_backward = _register_torch_operation("torch.ops.aten.max_pool2d_with_indices_backward", like=ltorch.max_pool2d_with_indices_backward)
+
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1478,7 +1547,7 @@ def max_pool2d_bwd_wrapper(
     return_indices: bool = False,
     ceil_mode: bool = False,
 ):
-    primals = max_pool2d(a, kernel_size, stride, padding, dilation, True, ceil_mode)
+    primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
     grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
     put_grad(a, grad_a)

From c9955117dd3eb593e95b97a4d485545690778a89 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 11 Apr 2024 21:07:11 +0000
Subject: [PATCH 09/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/executors/torchex.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index cb05d7fe92..5692ec7920 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1223,6 +1223,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
 
+
 def _max_pool_with_indices_helper(
     ndim: int,
     a: TensorProxy,
@@ -1234,19 +1235,21 @@ def _max_pool_with_indices_helper(
     ceil_mode: bool,
 ) -> [TensorProxy, TensorProxy]:
     def div_rtn(x, y):
-        q = x / y;
-        r = x % y;
+        q = x / y
+        r = x % y
         if r != 0 and (r < 0) != (y < 0):
             q -= 1
         return q
 
-    def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_:bool):
-        out_size = div_rtn(in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride - 1 if ceil_mode else 0), stride) + 1
+    def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_: bool):
+        out_size = (
+            div_rtn(in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride - 1 if ceil_mode else 0), stride) + 1
+        )
         if ceil_mode and (out_size - 1) * stride >= in_ + pad_:
             out_size -= 1
         return out_size
 
-    def get_maybe_ith_entry(seq : Sequence[int], i : int, default : int = None):
+    def get_maybe_ith_entry(seq: Sequence[int], i: int, default: int = None):
         if seq is None:
             return default
 
@@ -1257,7 +1260,7 @@ def get_maybe_ith_entry(seq : Sequence[int], i : int, default : int = None):
 
     out_sizes = []
     for i in range(ndim):
-        in_ = a.shape[i - ndim] # i - ndim is the i-th spatial dimension
+        in_ = a.shape[i - ndim]  # i - ndim is the i-th spatial dimension
         kernel_ = get_maybe_ith_entry(kernel_size, i)
         stride_ = get_maybe_ith_entry(stride, i, kernel_)
         pad_ = get_maybe_ith_entry(padding, i)
@@ -1266,6 +1269,7 @@ def get_maybe_ith_entry(seq : Sequence[int], i : int, default : int = None):
 
     return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
 
+
 def max_pool2d_with_indices_meta(
     a: TensorProxy,
     /,
@@ -1277,6 +1281,7 @@ def max_pool2d_with_indices_meta(
 ) -> [TensorProxy, TensorProxy]:
     return _max_pool_with_indices_helper(2, a, kernel_size, stride, padding, dilation, ceil_mode)
 
+
 def _max_pool2d_with_indices(
     a: TensorLike,
     /,
@@ -1289,7 +1294,9 @@ def _max_pool2d_with_indices(
     return torch.ops.aten.max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
 
-max_pool2d_with_indices = ex.register_operator("max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=_max_pool2d_with_indices)
+max_pool2d_with_indices = ex.register_operator(
+    "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=_max_pool2d_with_indices
+)
 max_pool2d_with_indices_backward = _register_torch_operation(
     "torch.ops.aten.max_pool2d_with_indices_backward", like=ltorch.max_pool2d_with_indices_backward
 )

From 82f56e97b72605f30cbcae3c8665bc879bb798b2 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 17:05:00 -0700
Subject: [PATCH 10/21] patch backward operator

---
 thunder/executors/torchex.py | 37 +++++++++++++++++++++++++++---------
 thunder/torch/__init__.py    |  2 +-
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 5692ec7920..4a098ad4f3 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1281,24 +1281,43 @@ def max_pool2d_with_indices_meta(
 ) -> [TensorProxy, TensorProxy]:
     return _max_pool_with_indices_helper(2, a, kernel_size, stride, padding, dilation, ceil_mode)
 
+<<<<<<< Updated upstream
 
 def _max_pool2d_with_indices(
     a: TensorLike,
     /,
+=======
+def max_pool2d_with_indices_backward_meta(
+    grad: TensorProxy,
+    a: TensorProxy,
+>>>>>>> Stashed changes
     kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    ceil_mode: bool = False,
-) -> [TensorLike, TensorLike]:
-    return torch.ops.aten.max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
+    result1: TensorProxy,
+) -> list[TensorProxy | None]:
+    return [TensorProxy(like=a), None]
+
+#def _max_pool2d_with_indices(
+#    a: TensorLike,
+#    /,
+#    kernel_size: int | Sequence[int],
+#    stride: int | Sequence[int] | None = None,
+#    padding: int | Sequence[int] = 0,
+#    dilation: int | Sequence[int] = 1,
+#    ceil_mode: bool = False,
+#) -> [TensorLike, TensorLike]:
+#    return torch.ops.aten.max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+
 
 max_pool2d_with_indices = ex.register_operator(
-    "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=_max_pool2d_with_indices
+    "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
 )
-max_pool2d_with_indices_backward = _register_torch_operation(
-    "torch.ops.aten.max_pool2d_with_indices_backward", like=ltorch.max_pool2d_with_indices_backward
+max_pool2d_with_indices_backward = ex.register_torch_operation(
+    "max_pool2d_with_indices_backward", meta=max_pool2d_with_indices_backward_meta, fn=torch.ops.aten.max_pool2d_with_indices_backward
 )
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index fe0dee4012..0e07deec9e 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3081,7 +3081,7 @@ def max_pool2d(
     return _max_pool_helper(2, a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
 
 
-@torchsymbol(torch.ops.aten.max_pool2d_with_indices_backward, id="max_pool2d_with_indices_backward", is_method=False)
+@torchsymbol(torch.max_pool2d_with_indices_backward, id="max_pool2d_with_indices_backward", is_method=False)
 def max_pool2d_with_indices_backward(
     grad: TensorProxy,
     a: TensorProxy,

From 22efa403a3793eac7e3e484927180c88196d90bc Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 17:07:33 -0700
Subject: [PATCH 11/21] patch

---
 thunder/executors/torchex.py |  7 -------
 thunder/torch/__init__.py    | 14 --------------
 2 files changed, 21 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 4a098ad4f3..e5912aac23 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1281,16 +1281,9 @@ def max_pool2d_with_indices_meta(
 ) -> [TensorProxy, TensorProxy]:
     return _max_pool_with_indices_helper(2, a, kernel_size, stride, padding, dilation, ceil_mode)
 
-<<<<<<< Updated upstream
-
-def _max_pool2d_with_indices(
-    a: TensorLike,
-    /,
-=======
 def max_pool2d_with_indices_backward_meta(
     grad: TensorProxy,
     a: TensorProxy,
->>>>>>> Stashed changes
     kernel_size: int | Sequence[int],
     stride: int | Sequence[int] | None,
     padding: int | Sequence[int],
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 0e07deec9e..0cee3f169b 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3081,20 +3081,6 @@ def max_pool2d(
     return _max_pool_helper(2, a, kernel_size, stride, padding, dilation, return_indices, ceil_mode)
 
 
-@torchsymbol(torch.max_pool2d_with_indices_backward, id="max_pool2d_with_indices_backward", is_method=False)
-def max_pool2d_with_indices_backward(
-    grad: TensorProxy,
-    a: TensorProxy,
-    kernel_size: int,
-    stride: int | Sequence[int] | None,
-    padding: int | Sequence[int],
-    dilation: int | Sequence[int],
-    ceil_mode: bool,
-    result1: TensorProxy,
-) -> list[TensorProxy | None]:
-    return [TensorProxy(like=a), None]
-
-
 @torchsymbol(torch.max_pool3d, torch.nn.functional.max_pool3d, id="torch.nn.functional.max_pool3d", is_method=False)
 def max_pool3d(
     a: TensorProxy,

From 254fcb0908f51d3e9065221a309138323d39be6e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Apr 2024 00:08:18 +0000
Subject: [PATCH 12/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/executors/torchex.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index e5912aac23..85a8835898 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1281,6 +1281,7 @@ def max_pool2d_with_indices_meta(
 ) -> [TensorProxy, TensorProxy]:
     return _max_pool_with_indices_helper(2, a, kernel_size, stride, padding, dilation, ceil_mode)
 
+
 def max_pool2d_with_indices_backward_meta(
     grad: TensorProxy,
     a: TensorProxy,
@@ -1293,7 +1294,8 @@ def max_pool2d_with_indices_backward_meta(
 ) -> list[TensorProxy | None]:
     return [TensorProxy(like=a), None]
 
-#def _max_pool2d_with_indices(
+
+# def _max_pool2d_with_indices(
 #    a: TensorLike,
 #    /,
 #    kernel_size: int | Sequence[int],
@@ -1301,16 +1303,17 @@ def max_pool2d_with_indices_backward_meta(
 #    padding: int | Sequence[int] = 0,
 #    dilation: int | Sequence[int] = 1,
 #    ceil_mode: bool = False,
-#) -> [TensorLike, TensorLike]:
+# ) -> [TensorLike, TensorLike]:
 #    return torch.ops.aten.max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
-
 max_pool2d_with_indices = ex.register_operator(
     "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
 )
-max_pool2d_with_indices_backward = ex.register_torch_operation(
-    "max_pool2d_with_indices_backward", meta=max_pool2d_with_indices_backward_meta, fn=torch.ops.aten.max_pool2d_with_indices_backward
+max_pool2d_with_indices_backward = ex.register_torch_operation(
+    "max_pool2d_with_indices_backward",
+    meta=max_pool2d_with_indices_backward_meta,
+    fn=torch.ops.aten.max_pool2d_with_indices_backward,
 )
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)

From 0e9c441b4de33bc647b6b0c49dc3d1b81adefe74 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 17:14:44 -0700
Subject: [PATCH 13/21] fixing logic

---
 thunder/executors/torchex.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 85a8835898..bd64267748 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1235,7 +1235,7 @@ def _max_pool_with_indices_helper(
     ceil_mode: bool,
 ) -> [TensorProxy, TensorProxy]:
     def div_rtn(x, y):
-        q = x / y
+        q = x // y
         r = x % y
         if r != 0 and (r < 0) != (y < 0):
             q -= 1
@@ -1253,6 +1253,9 @@ def get_maybe_ith_entry(seq: Sequence[int], i: int, default: int = None):
         if seq is None:
             return default
 
+        if not isinstance(seq, Sequence):
+            return seq
+
         if len(seq) == 1:
             return seq[0]
         else:
@@ -1574,7 +1577,7 @@ def max_pool2d_bwd_wrapper(
     primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
-    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    grad_a, _ = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
     put_grad(a, grad_a)

From d8647bbc374ec13b0e68ff0361ddaff17aaf0791 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 17:32:01 -0700
Subject: [PATCH 14/21] functionally correct now at least

---
 thunder/executors/torchex.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index bd64267748..61278a5c9d 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1294,8 +1294,8 @@ def max_pool2d_with_indices_backward_meta(
     dilation: int | Sequence[int],
     ceil_mode: bool,
     result1: TensorProxy,
-) -> list[TensorProxy | None]:
-    return [TensorProxy(like=a), None]
+) -> TensorProxy:
+    return TensorProxy(like=a)
 
 
 # def _max_pool2d_with_indices(
@@ -1313,7 +1313,7 @@ def max_pool2d_with_indices_backward_meta(
 max_pool2d_with_indices = ex.register_operator(
     "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
 )
-max_pool2d_with_indices_backward = ex.register_torch_operation(
+max_pool2d_with_indices_backward = ex.register_operator(
     "max_pool2d_with_indices_backward",
     meta=max_pool2d_with_indices_backward_meta,
     fn=torch.ops.aten.max_pool2d_with_indices_backward,
@@ -1577,7 +1577,7 @@ def max_pool2d_bwd_wrapper(
     primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
-    grad_a, _ = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
     put_grad(a, grad_a)

From 7a06c079a8f926f6c5992b569b2e6633e9aa4429 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 22:11:13 -0700
Subject: [PATCH 15/21] refactor to support max_pool3d as well

---
 thunder/executors/torchex.py | 67 ++++++++++++++++--------------------
 1 file changed, 30 insertions(+), 37 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 61278a5c9d..b6cb9f83a3 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1229,10 +1229,10 @@ def _max_pool_with_indices_helper(
     a: TensorProxy,
     /,
     kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None,
-    padding: int | Sequence[int],
-    dilation: int | Sequence[int],
-    ceil_mode: bool,
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
 ) -> [TensorProxy, TensorProxy]:
     def div_rtn(x, y):
         q = x // y
@@ -1272,20 +1272,7 @@ def get_maybe_ith_entry(seq: Sequence[int], i: int, default: int = None):
 
     return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
 
-
-def max_pool2d_with_indices_meta(
-    a: TensorProxy,
-    /,
-    kernel_size: int | Sequence[int],
-    stride: int | Sequence[int] | None = None,
-    padding: int | Sequence[int] = 0,
-    dilation: int | Sequence[int] = 1,
-    ceil_mode: bool = False,
-) -> [TensorProxy, TensorProxy]:
-    return _max_pool_with_indices_helper(2, a, kernel_size, stride, padding, dilation, ceil_mode)
-
-
-def max_pool2d_with_indices_backward_meta(
+def max_pool_with_indices_backward_meta(
     grad: TensorProxy,
     a: TensorProxy,
     kernel_size: int | Sequence[int],
@@ -1297,27 +1284,28 @@ def max_pool2d_with_indices_backward_meta(
 ) -> TensorProxy:
     return TensorProxy(like=a)
 
-
-# def _max_pool2d_with_indices(
-#    a: TensorLike,
-#    /,
-#    kernel_size: int | Sequence[int],
-#    stride: int | Sequence[int] | None = None,
-#    padding: int | Sequence[int] = 0,
-#    dilation: int | Sequence[int] = 1,
-#    ceil_mode: bool = False,
-# ) -> [TensorLike, TensorLike]:
-#    return torch.ops.aten.max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
-
+max_pool2d_with_indices_meta = partial(_max_pool_with_indices_helper, 2)
 
 max_pool2d_with_indices = ex.register_operator(
     "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
 )
 max_pool2d_with_indices_backward = ex.register_operator(
     "max_pool2d_with_indices_backward",
-    meta=max_pool2d_with_indices_backward_meta,
+    meta=max_pool_with_indices_backward_meta,
     fn=torch.ops.aten.max_pool2d_with_indices_backward,
 )
+
+max_pool3d_with_indices_meta = partial(_max_pool_with_indices_helper, 3)
+
+max_pool3d_with_indices = ex.register_operator(
+    "max_pool3d_with_indices", meta=max_pool3d_with_indices_meta, fn=torch.ops.aten.max_pool3d_with_indices
+)
+max_pool3d_with_indices_backward = ex.register_operator(
+    "max_pool3d_with_indices_backward",
+    meta=max_pool_with_indices_backward_meta,
+    fn=torch.ops.aten.max_pool3d_with_indices_backward,
+)
+
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1564,7 +1552,9 @@ def _pad_prim_impl(
 _register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
 
 
-def max_pool2d_bwd_wrapper(
+def max_pool_bwd_wrapper(
+    fwd_fn,
+    bwd_fn,
     a: TensorProxy,
     /,
     kernel_size: int | Sequence[int],
@@ -1574,10 +1564,10 @@ def max_pool2d_bwd_wrapper(
     return_indices: bool = False,
     ceil_mode: bool = False,
 ):
-    primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+    primals = fwd_fn(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
-    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    grad_a = bwd_fn(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
     put_grad(a, grad_a)
 
     if return_indices:
@@ -1585,11 +1575,14 @@ def max_pool_bwd_wrapper(
     else:
         return primals[0]
 
-
+max_pool2d_bwd_wrapper = partial(max_pool_bwd_wrapper, max_pool2d_with_indices, max_pool2d_with_indices_backward)
 ex.register_implementation(
-    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
+    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool_bwd_wrapper
 )
+max_pool3d_bwd_wrapper = partial(max_pool_bwd_wrapper, max_pool3d_with_indices, max_pool3d_with_indices_backward)
+ex.register_implementation(
+    ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool_bwd_wrapper
+)
-_register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
     "torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl

From ef7ae3a5a103fa2dc970570ab94f913ee7e2cbb1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Apr 2024 05:12:25 +0000
Subject: [PATCH 16/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/executors/torchex.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index b6cb9f83a3..a9c12fd266 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1272,6 +1272,7 @@ def get_maybe_ith_entry(seq: Sequence[int], i: int, default: int = None):
 
     return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
 
+
 def max_pool_with_indices_backward_meta(
     grad: TensorProxy,
     a: TensorProxy,
@@ -1284,6 +1285,7 @@ def max_pool_with_indices_backward_meta(
 ) -> TensorProxy:
     return TensorProxy(like=a)
 
+
 max_pool2d_with_indices_meta = partial(_max_pool_with_indices_helper, 2)
 
 max_pool2d_with_indices = ex.register_operator(
@@ -1575,6 +1577,7 @@ def max_pool_bwd_wrapper(
     else:
         return primals[0]
 
+
 max_pool2d_bwd_wrapper = partial(max_pool_bwd_wrapper, max_pool2d_with_indices, max_pool2d_with_indices_backward)
 ex.register_implementation(
     ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool_bwd_wrapper

From 351469d50c4f23ee8db2392bc1aaa0ae83aa2fbc Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Thu, 11 Apr 2024 22:27:40 -0700
Subject: [PATCH 17/21] partial can't be used in grad_transform

---
 thunder/executors/torchex.py | 35 ++++++++++++++++++++++++++---------
 1 file changed, 26 insertions(+), 9 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index a9c12fd266..7f82a05d56 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1554,9 +1554,7 @@ def _pad_prim_impl(
 _register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
 
 
-def max_pool_bwd_wrapper(
-    fwd_fn,
-    bwd_fn,
+def max_pool2d_bwd_wrapper(
     a: TensorProxy,
     /,
     kernel_size: int | Sequence[int],
@@ -1566,10 +1564,31 @@ def max_pool_bwd_wrapper(
     return_indices: bool = False,
     ceil_mode: bool = False,
 ):
-    primals = fwd_fn(a, kernel_size, stride, padding, dilation, ceil_mode)
+    primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
-    grad_a = bwd_fn(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    put_grad(a, grad_a)
+
+    if return_indices:
+        return primals
+    else:
+        return primals[0]
+
+def max_pool3d_bwd_wrapper(
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+):
+    primals = max_pool3d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+
+    grad = get_grad(primals[0])
+    grad_a = max_pool3d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
     put_grad(a, grad_a)
 
     if return_indices:
@@ -1578,13 +1597,11 @@ def max_pool_bwd_wrapper(
         return primals[0]
 
 
-max_pool2d_bwd_wrapper = partial(max_pool_bwd_wrapper, max_pool2d_with_indices, max_pool2d_with_indices_backward)
 ex.register_implementation(
-    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool_bwd_wrapper
+    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
 )
-max_pool3d_bwd_wrapper = partial(max_pool_bwd_wrapper, max_pool3d_with_indices, max_pool3d_with_indices_backward)
 ex.register_implementation(
-    ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool_bwd_wrapper
+    ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool3d_bwd_wrapper
 )
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
     "torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl

From f8a167fc8b0a0ad0f37cf0bda02fb716461d2245 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Apr 2024 05:28:22 +0000
Subject: [PATCH 18/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/executors/torchex.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 7f82a05d56..7d299b58a9 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1575,6 +1575,7 @@ def max_pool2d_bwd_wrapper(
     else:
         return primals[0]
 
+
 def max_pool3d_bwd_wrapper(
     a: TensorProxy,
     /,

From dfaa0584fcf68470f102d8fd255b3323a9530198 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Fri, 12 Apr 2024 11:35:18 -0700
Subject: [PATCH 19/21] addressing reviews

---
 thunder/executors/torchex.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 7d299b58a9..c708bbf819 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1249,7 +1249,7 @@ def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_: boo
             out_size -= 1
         return out_size
 
-    def get_maybe_ith_entry(seq: Sequence[int], i: int, default: int = None):
+    def get_maybe_ith_entry(arg_name: str, seq: int | Sequence[int], i: int, default: int | None = None):
         if seq is None:
             return default
 
@@ -1259,15 +1259,22 @@ def get_maybe_ith_entry(seq: Sequence[int], i: int, default: int = None):
         if len(seq) == 1:
             return seq[0]
         else:
+            utils.check(
+                i < len(seq),
+                lambda: f"invalid pooling argument: {arg_name} needs to be None / a scalar / size-{ndim} Sequence, but received {seq}",
+            )
             return seq[i]
 
     out_sizes = []
     for i in range(ndim):
         in_ = a.shape[i - ndim]  # i - ndim is the i-th spatial dimension
-        kernel_ = get_maybe_ith_entry(kernel_size, i)
-        stride_ = get_maybe_ith_entry(stride, i, kernel_)
-        pad_ = get_maybe_ith_entry(padding, i)
-        dilation_ = get_maybe_ith_entry(dilation, i)
+        kernel_ = get_maybe_ith_entry("kernel_size", kernel_size, i)
+        stride_ = get_maybe_ith_entry("stride", stride, i, kernel_)
+        pad_ = get_maybe_ith_entry("padding", padding, i)
+        dilation_ = get_maybe_ith_entry("dilation", dilation, i)
+        utils.check(
+            kernel_ is not None and stride_ is not None and pad_ is not None and dilation_ is not None,
+            lambda: f"max_pool argument extraction failed.")
         out_sizes.append(pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode))
 
     return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
@@ -1563,7 +1570,7 @@ def max_pool2d_bwd_wrapper(
     dilation: int | Sequence[int] = 1,
     return_indices: bool = False,
     ceil_mode: bool = False,
-):
+): tuple[TensorProxy, TensorProxy] | TensorProxy:
     primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
@@ -1585,7 +1592,7 @@ def max_pool3d_bwd_wrapper(
     dilation: int | Sequence[int] = 1,
     return_indices: bool = False,
     ceil_mode: bool = False,
-):
+): tuple[TensorProxy, TensorProxy] | TensorProxy:
     primals = max_pool3d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
@@ -1598,6 +1605,7 @@ def max_pool3d_bwd_wrapper(
         return primals[0]
 
 
+# The ltorch.max_pool2d/3d decomposition uses convolution, which has a performance issue when run through torchex. We add a grad_transform that keeps both the forward and backward max_pool as torch composite operators, which avoids the performance issue. For details: https://github.com/Lightning-AI/lightning-thunder/issues/164. ATen doesn't expose explicit max_pool1d fwd/bwd functions, so the specialization is only done for the 2d/3d cases.
 ex.register_implementation(
     ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
 )

From 155ac78ec06e06b41e854c0e33e276f54793d558 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Fri, 12 Apr 2024 11:36:38 -0700
Subject: [PATCH 20/21] typo

---
 thunder/executors/torchex.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index c708bbf819..2d2825b0ab 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1570,7 +1570,7 @@ def max_pool2d_bwd_wrapper(
     dilation: int | Sequence[int] = 1,
     return_indices: bool = False,
     ceil_mode: bool = False,
-): tuple[TensorProxy, TensorProxy] | TensorProxy:
+) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
     primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])
@@ -1592,7 +1592,7 @@ def max_pool3d_bwd_wrapper(
     dilation: int | Sequence[int] = 1,
     return_indices: bool = False,
     ceil_mode: bool = False,
-): tuple[TensorProxy, TensorProxy] | TensorProxy:
+) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
     primals = max_pool3d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
 
     grad = get_grad(primals[0])

From 2a8da1b7d9dfa95bb6aa748cf96bc3ffe64418dc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Apr 2024 18:37:29 +0000
Subject: [PATCH 21/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/executors/torchex.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 2d2825b0ab..bd991da611 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1274,7 +1274,8 @@ def get_maybe_ith_entry(arg_name: str, seq: int | Sequence[int], i: int, default
         dilation_ = get_maybe_ith_entry("dilation", dilation, i)
         utils.check(
             kernel_ is not None and stride_ is not None and pad_ is not None and dilation_ is not None,
-            lambda: f"max_pool argument extraction failed.")
+            lambda: f"max_pool argument extraction failed.",
+        )
         out_sizes.append(pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode))
 
     return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
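
The output-size arithmetic that the meta functions in this series reimplement follows ATen's pooling_output_shape. A standalone sketch, written here for illustration (the helper and the check against torch.nn.functional.max_pool2d are not part of the patches):

    import torch

    def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode):
        # Python's // rounds toward negative infinity, matching ATen's div_rtn
        # for the positive strides used in pooling
        out = (in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride_ - 1 if ceil_mode else 0)) // stride_ + 1
        # in ceil mode, the last window must start inside the padded input
        if ceil_mode and (out - 1) * stride_ >= in_ + pad_:
            out -= 1
        return out

    out = torch.nn.functional.max_pool2d(
        torch.randn(1, 1, 10, 10), kernel_size=3, stride=2, padding=1, ceil_mode=True
    )
    assert out.shape[-1] == pooling_output_shape(10, 3, 1, 2, 1, True)  # both give 6

One thing worth noting: the in-tree helper's inner pooling_output_shape reads the enclosing stride and ceil_mode names instead of its stride_ and ceil_mode_ parameters, which only works by accident when the per-dimension values coincide; the sketch above uses the parameters throughout.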
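
The executor ends up binding the forward/backward pair directly to the ATen ops. That pairing can be checked outside of thunder (the tensor and pooling parameters below are arbitrary):

    import torch

    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    # forward returns (pooled values, flat indices of the argmax elements)
    out, indices = torch.ops.aten.max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0], [1, 1], False)
    # backward scatters the incoming gradient back onto the argmax positions
    grad_x = torch.ops.aten.max_pool2d_with_indices_backward(
        torch.ones_like(out), x, [2, 2], [2, 2], [0, 0], [1, 1], False, indices
    )
    out.sum().backward()
    assert torch.equal(grad_x, x.grad)

Always computing the forward with indices (as the final wrappers do) costs an extra output tensor, but it makes the backward a single index-scatter rather than a recomputation of the pooling windows.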
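
Patch 17's subject, "partial can't be used in grad_transform", is consistent with the registration path introspecting the callable: functools.partial objects, for instance, carry no __name__. The exact failure inside thunder is not shown in the series, so treat this as a plausible mechanism rather than the confirmed one:

    from functools import partial

    def max_pool_bwd_wrapper(fwd_fn, bwd_fn, a):
        ...  # schematic stub standing in for the shared wrapper

    wrapper = partial(max_pool_bwd_wrapper, None, None)
    print(hasattr(max_pool_bwd_wrapper, "__name__"))  # True
    print(hasattr(wrapper, "__name__"))               # False

Duplicating the wrapper per dimensionality, as patch 17 does, sidesteps the issue at the cost of a little repetition.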
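
With the final registrations in place, the intended end-to-end behavior can be smoke-tested roughly as follows (assuming the thunder.jit entry point; how to inspect the produced traces varies across versions, so the comment only states the expectation):

    import torch
    import thunder

    def f(x):
        return torch.nn.functional.max_pool2d(x, kernel_size=2)

    jf = thunder.jit(f)
    x = torch.randn(1, 1, 8, 8, requires_grad=True)
    jf(x).sum().backward()
    # the backward trace should now call max_pool2d_with_indices_backward
    # rather than decomposing max_pool2d into convolution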