diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 6345e2780f..bd991da611 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -37,6 +37,11 @@
 from thunder.extend import OperatorExecutor, register_executor, add_always_executor
 
+from thunder.core.transforms import (
+    get_grad,
+    put_grad,
+)
+
 ex = OperatorExecutor("torch", version=torch.__version__)
 register_executor(ex)
 add_always_executor(ex)
 
@@ -1217,6 +1222,100 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 max_pool1d = _register_torch_operation("max_pool1d", module=torch.nn.functional)
 max_pool2d = _register_torch_operation("max_pool2d", module=torch.nn.functional)
 max_pool3d = _register_torch_operation("max_pool3d", module=torch.nn.functional)
+
+
+def _max_pool_with_indices_helper(
+    ndim: int,
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    ceil_mode: bool = False,
+) -> tuple[TensorProxy, TensorProxy]:
+    def div_rtn(x, y):  # port of ATen's div_rtn: integer division rounded toward negative infinity
+        q = x // y
+        r = x % y
+        if r != 0 and (r < 0) != (y < 0):
+            q -= 1
+        return q
+
+    def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_: bool):
+        out_size = (
+            div_rtn(in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride_ - 1 if ceil_mode_ else 0), stride_) + 1
+        )
+        if ceil_mode_ and (out_size - 1) * stride_ >= in_ + pad_:  # last window must not start in the right padding
+            out_size -= 1
+        return out_size
+
+    def get_maybe_ith_entry(arg_name: str, seq: int | Sequence[int], i: int, default: int | None = None):
+        if seq is None:
+            return default
+
+        if not isinstance(seq, Sequence):
+            return seq
+
+        if len(seq) == 1:
+            return seq[0]
+        else:
+            utils.check(
+                i < len(seq),
+                lambda: f"invalid pooling argument: {arg_name} needs to be None / a scalar / a size-{ndim} Sequence, but received {seq}",
+            )
+            return seq[i]
+
+    out_sizes = []
+    for i in range(ndim):
+        in_ = a.shape[i - ndim]  # i - ndim indexes the i-th spatial dimension from the back
+        kernel_ = get_maybe_ith_entry("kernel_size", kernel_size, i)
+        stride_ = get_maybe_ith_entry("stride", stride, i, kernel_)
+        pad_ = get_maybe_ith_entry("padding", padding, i)
+        dilation_ = get_maybe_ith_entry("dilation", dilation, i)
+        utils.check(
+            kernel_ is not None and stride_ is not None and pad_ is not None and dilation_ is not None,
+            lambda: "max_pool argument extraction failed.",
+        )
+        out_sizes.append(pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode))
+
+    return TensorProxy(like=a, shape=out_sizes), TensorProxy(like=a, shape=out_sizes)
+
+
+def max_pool_with_indices_backward_meta(
+    grad: TensorProxy,
+    a: TensorProxy,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None,
+    padding: int | Sequence[int],
+    dilation: int | Sequence[int],
+    ceil_mode: bool,
+    result1: TensorProxy,
+) -> TensorProxy:
+    return TensorProxy(like=a)
+
+
+max_pool2d_with_indices_meta = partial(_max_pool_with_indices_helper, 2)
+
+max_pool2d_with_indices = ex.register_operator(
+    "max_pool2d_with_indices", meta=max_pool2d_with_indices_meta, fn=torch.ops.aten.max_pool2d_with_indices
+)
+max_pool2d_with_indices_backward = ex.register_operator(
+    "max_pool2d_with_indices_backward",
+    meta=max_pool_with_indices_backward_meta,
+    fn=torch.ops.aten.max_pool2d_with_indices_backward,
+)
+
+max_pool3d_with_indices_meta = partial(_max_pool_with_indices_helper, 3)
+
+max_pool3d_with_indices = ex.register_operator(
+    "max_pool3d_with_indices", meta=max_pool3d_with_indices_meta, fn=torch.ops.aten.max_pool3d_with_indices
+)
+max_pool3d_with_indices_backward = ex.register_operator(
+    "max_pool3d_with_indices_backward",
+    meta=max_pool_with_indices_backward_meta,
+    fn=torch.ops.aten.max_pool3d_with_indices_backward,
+)
+
 nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
 pad = _register_torch_operation("pad", module=torch.nn.functional)
 scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
@@ -1461,8 +1560,59 @@ def _pad_prim_impl(
     ltorch.log_softmax_backward, checker=_always_executable, execution_transform=_log_softmax_backward_transform
 )
 _register_implementation(ltorch.max_pool1d, max_pool1d, checker=_always_executable)
-_register_implementation(ltorch.max_pool2d, max_pool2d, checker=_always_executable)
-_register_implementation(ltorch.max_pool3d, max_pool3d, checker=_always_executable)
+
+
+def max_pool2d_bwd_wrapper(
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
+    primals = max_pool2d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+
+    grad = get_grad(primals[0])
+    grad_a = max_pool2d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    put_grad(a, grad_a)
+
+    if return_indices:
+        return primals
+    else:
+        return primals[0]
+
+
+def max_pool3d_bwd_wrapper(
+    a: TensorProxy,
+    /,
+    kernel_size: int | Sequence[int],
+    stride: int | Sequence[int] | None = None,
+    padding: int | Sequence[int] = 0,
+    dilation: int | Sequence[int] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+) -> tuple[TensorProxy, TensorProxy] | TensorProxy:
+    primals = max_pool3d_with_indices(a, kernel_size, stride, padding, dilation, ceil_mode)
+
+    grad = get_grad(primals[0])
+    grad_a = max_pool3d_with_indices_backward(grad, a, kernel_size, stride, padding, dilation, ceil_mode, primals[1])
+    put_grad(a, grad_a)
+
+    if return_indices:
+        return primals
+    else:
+        return primals[0]
+
+
+# The ltorch.max_pool2d/3d decomposition uses convolution, which has a performance issue when run through torchex. We add a grad_transform that keeps both the forward and backward of max_pool as composite Torch operators, which avoids the performance issue. For details see https://github.com/Lightning-AI/lightning-thunder/issues/164. ATen doesn't have explicit functions for max_pool1d fwd/bwd, so the specialization is only done for the 2d/3d cases.
+ex.register_implementation(
+    ltorch.max_pool2d, max_pool2d, checker=_always_executable, grad_transform=max_pool2d_bwd_wrapper
+)
+ex.register_implementation(
+    ltorch.max_pool3d, max_pool3d, checker=_always_executable, grad_transform=max_pool3d_bwd_wrapper
+)
 _register_implementation(ltorch.nll_loss, checker=_always_executable, execution_transform=_nll_loss_transform)
 nll_loss_backward = ex.register_operator(
     "torch_nll_loss_backward_impl", meta=ltorch.nll_loss_backward, fn=_nll_loss_backward_impl
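
Note on the shape arithmetic in _max_pool_with_indices_helper: pooling_output_shape ports ATen's formula out = div_rtn(in + 2*pad - dilation*(kernel - 1) - 1 + (stride - 1 if ceil_mode else 0), stride) + 1, where the ceil_mode term makes the division round up and the final correction drops a window that would start entirely in the right padding. A minimal standalone sketch (not part of the diff) checking the formula against torch.nn.functional.max_pool1d:

import torch

def pooling_output_shape(in_, kernel_, pad_, stride_, dilation_, ceil_mode_):
    # Numerator of ATen's formula; ceil_mode adds stride_ - 1 so the division rounds up.
    num = in_ + 2 * pad_ - dilation_ * (kernel_ - 1) - 1 + (stride_ - 1 if ceil_mode_ else 0)
    out = num // stride_ + 1  # div_rtn reduces to floor division for these non-negative values
    if ceil_mode_ and (out - 1) * stride_ >= in_ + pad_:
        out -= 1  # the last window would start in the right padding; drop it
    return out

x = torch.randn(1, 1, 8)
for ceil_mode in (False, True):
    expected = torch.nn.functional.max_pool1d(x, kernel_size=3, stride=2, ceil_mode=ceil_mode).shape[-1]
    assert pooling_output_shape(8, 3, 0, 2, 1, ceil_mode) == expected  # 3 and 4, respectively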
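
Usage sketch of what the registrations enable (a sketch assuming the public thunder.jit entry point; not part of the diff): with this change the torch executor keeps the max_pool2d/3d forward as torch.ops.aten.max_pool2d/3d_with_indices and routes the backward through the matching _backward op instead of decomposing to convolution.

import torch
import thunder

def fn(x):
    return torch.nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1)

jfn = thunder.jit(fn)  # forward stays a composite aten op under the torch executor
x = torch.randn(2, 3, 32, 32, requires_grad=True)
jfn(x).sum().backward()  # backward runs through max_pool2d_with_indices_backward via the grad_transform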