From 54502a1789335ab43f8c65bfbf9cd8e8d01ae5ad Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 30 Jan 2024 15:05:26 +0100 Subject: [PATCH] Bbonev/disco refactor (#29) * Cleaned up DISCO convolutions --- Changelog.md | 9 +- torch_harmonics/__init__.py | 2 +- torch_harmonics/convolution.py | 259 ++++++++++++++++++++------------- torch_harmonics/quadrature.py | 43 +++++- 4 files changed, 196 insertions(+), 117 deletions(-) diff --git a/Changelog.md b/Changelog.md index d5c0e75..6c93363 100644 --- a/Changelog.md +++ b/Changelog.md @@ -4,10 +4,11 @@ ### v0.6.5 -* Discrrete-continuous (DISCO) convolutions on the sphere -* Isotropic and anisotropic DISCO convolutions -* Accelerated DISCO convolutions on GPU via Triton implementation -* Unittests for DISCO convolutions +* Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions +* DISCO supports isotropic and anisotropic kernel functions parameterized as hat functions +* Supports regular and transpose convolutions +* Accelerated spherical DISCO convolutions on GPU via Triton implementation +* Unittests for DISCO convolutions in `tests/test_convolution.py` ### v0.6.4 diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 3366517..6bf8c4f 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -29,7 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -__version__ = '0.6.4' +__version__ = '0.6.5' from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py index 363389e..834c3ea 100644 --- a/torch_harmonics/convolution.py +++ b/torch_harmonics/convolution.py @@ -29,6 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +import abc from typing import List, Tuple, Union, Optional import math @@ -38,7 +39,7 @@ from functools import partial -from torch_harmonics.quadrature import _precompute_latitudes +from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes from torch_harmonics._disco_convolution import ( _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch, @@ -47,50 +48,67 @@ ) -def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float): +def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"): """ Computes the index set that falls into the isotropic kernel's support and returns both indices and values. 
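+
+    The kernel consists of nr piecewise-linear hat functions in the radial coordinate r, centered at
+    the radii ikernel * dr with spacing dr = r_cutoff / nr. The returned vals are the hat function
+    values at the supported points indexed by iidx, divided by the normalization factor selected
+    via norm ("none", "2d" or "s2").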
""" # compute the support - dtheta = (theta_cutoff - 0.0) / ntheta - ikernel = torch.arange(ntheta).reshape(-1, 1, 1) - itheta = ikernel * dtheta - - norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) + dr = (r_cutoff - 0.0) / nr + ikernel = torch.arange(nr).reshape(-1, 1, 1) + ir = ikernel * dr + + if norm == "none": + norm_factor = 1.0 + elif norm == "2d": + norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3 + elif norm == "s2": + norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr) + else: + raise ValueError(f"Unknown normalization mode {norm}.") # find the indices where the rotated position falls into the support of the kernel - iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)) - vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor + iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) + vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor return iidx, vals -def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float): + +def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float, norm: str = "s2"): """ Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. """ # compute the support - dtheta = (theta_cutoff - 0.0) / ntheta + dr = (r_cutoff - 0.0) / nr dphi = 2.0 * math.pi / nphi - kernel_size = (ntheta-1)*nphi + 1 + kernel_size = (nr - 1) * nphi + 1 ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) - itheta = ((ikernel - 1) // nphi + 1) * dtheta + ir = ((ikernel - 1) // nphi + 1) * dr iphi = ((ikernel - 1) % nphi) * dphi - norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) + if norm == "none": + norm_factor = 1.0 + elif norm == "2d": + norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3 + elif norm == "s2": + norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr) + else: + raise ValueError(f"Unknown normalization mode {norm}.") # find the indices where the rotated position falls into the support of the kernel - cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff) - cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi) - iidx = torch.argwhere(cond_theta & cond_phi) - vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor - vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0) + cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) + cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi) + iidx = torch.argwhere(cond_r & cond_phi) + vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor + vals *= torch.where( + iidx[:, 0] 
> 0,
+        (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
+        1.0,
+    )
 
     return iidx, vals
 
 
-def _precompute_convolution_tensor(
-    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
-):
+def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
     Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
@@ -111,9 +129,9 @@ def _precompute_convolution_tensor(
     assert len(out_shape) == 2
 
     if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2")
     elif len(kernel_shape) == 2:
-        kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2")
     else:
         raise ValueError("kernel_shape should be either one- or two-dimensional.")
 
@@ -131,24 +149,24 @@ def _precompute_convolution_tensor(
     # compute the phi differences
     # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
-    lons_in = torch.linspace(0, 2*math.pi, nlon_in+1)[:-1]
+    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
 
     for t in range(nlat_out):
         # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
-        alpha = - lats_out[t]
+        alpha = -lats_out[t]
         beta = lons_in
         gamma = lats_in.reshape(-1, 1)
 
         # compute cartesian coordinates of the rotated position
         # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
         # and therefore applied with a negative sign
-        z = - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
+        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
         x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
         y = torch.sin(beta) * torch.sin(gamma)
-
+
         # normalization is important to avoid NaNs when arccos and atan are applied
         # this can otherwise lead to spurious artifacts in the solution
-        norm = torch.sqrt(x*x + y*y + z*z)
+        norm = torch.sqrt(x * x + y * y + z * z)
         x = x / norm
         y = y / norm
         z = z / norm
@@ -170,9 +188,96 @@ def _precompute_convolution_tensor(
 
     return out_idx, out_vals
 
 
-# TODO:
-# - derive conv and conv transpose from single module
-class DiscreteContinuousConvS2(nn.Module):
+def _precompute_convolution_tensor_2d(grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False):
+    """
+    Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
Similar to the S2 routine,
+    except that it assumes a subset of the Euclidean plane (non-periodic unless periodic=True).
+    """
+
+    # check that input arrays are valid point clouds in 2D
+    assert len(grid_in) == 2
+    assert len(grid_out) == 2
+    assert grid_in.shape[0] == 2
+    assert grid_out.shape[0] == 2
+
+    n_in = grid_in.shape[-1]
+    n_out = grid_out.shape[-1]
+
+    if len(kernel_shape) == 1:
+        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=radius_cutoff, norm="2d")
+    elif len(kernel_shape) == 2:
+        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=radius_cutoff, norm="2d")
+    else:
+        raise ValueError("kernel_shape should be either one- or two-dimensional.")
+
+    grid_in = grid_in.reshape(2, 1, n_in)
+    grid_out = grid_out.reshape(2, n_out, 1)
+
+    diffs = grid_in - grid_out
+    if periodic:
+        periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
+        diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)
+
+    r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
+    phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi
+
+    idx, vals = kernel_handle(r, phi)
+    idx = idx.permute(1, 0)
+
+    return idx, vals
+
+
+class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
+    """
+    Abstract base class for DISCO convolutions
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_shape: Union[int, List[int]],
+        groups: Optional[int] = 1,
+        bias: Optional[bool] = True,
+    ):
+        super().__init__()
+
+        if isinstance(kernel_shape, int):
+            self.kernel_shape = [kernel_shape]
+        else:
+            self.kernel_shape = kernel_shape
+
+        if len(self.kernel_shape) == 1:
+            self.kernel_size = self.kernel_shape[0]
+        elif len(self.kernel_shape) == 2:
+            self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
+        else:
+            raise ValueError("kernel_shape should be either one- or two-dimensional.")
+
+        # groups
+        self.groups = groups
+
+        # weight tensor
+        if in_channels % self.groups != 0:
+            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
+        if out_channels % self.groups != 0:
+            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
+        self.groupsize = in_channels // self.groups
+        scale = math.sqrt(1.0 / self.groupsize)
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels))
+        else:
+            self.bias = None
+
+    @abc.abstractmethod
+    def forward(self, x: torch.Tensor):
+        raise NotImplementedError
+
+
+class DiscreteContinuousConvS2(DiscreteContinuousConv):
     """
     Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1]. 
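+
+    A minimal usage sketch (shapes and grids chosen arbitrarily for illustration):
+
+    >>> conv = DiscreteContinuousConvS2(in_channels=4, out_channels=8,
+    ...                                 in_shape=(60, 120), out_shape=(60, 120),
+    ...                                 kernel_shape=[3, 4])
+    >>> x = torch.randn(2, 4, 60, 120)  # (batch, channels, nlat_in, nlon_in)
+    >>> y = conv(x)                     # shape (2, 8, 60, 120)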
@@ -192,24 +297,14 @@ def __init__( bias: Optional[bool] = True, theta_cutoff: Optional[float] = None, ): - super().__init__() + super().__init__(in_channels, out_channels, kernel_shape, groups, bias) self.nlat_in, self.nlon_in = in_shape self.nlat_out, self.nlon_out = out_shape - if isinstance(kernel_shape, int): - kernel_shape = [kernel_shape] - if len(kernel_shape) == 1: - self.kernel_size = kernel_shape[0] - elif len(kernel_shape) == 2: - self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1 - else: - raise ValueError("kernel_shape should be either one- or two-dimensional.") - - # compute theta cutoff based on the bandlimit of the input field if theta_cutoff is None: - theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1) + theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1) if theta_cutoff <= 0.0: raise ValueError("Error, theta_cutoff has to be positive.") @@ -219,38 +314,20 @@ def __init__( quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in self.register_buffer("quad_weights", quad_weights, persistent=False) - idx, vals = _precompute_convolution_tensor( - in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff - ) - # psi = torch.sparse_coo_tensor( - # idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in) - # ).coalesce() + idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff) + self.register_buffer("psi_idx", idx, persistent=False) self.register_buffer("psi_vals", vals, persistent=False) - # self.register_buffer("psi", psi, persistent=False) - - # groups - self.groups = groups - - # weight tensor - if in_channels % self.groups != 0: - raise ValueError("Error, the number of input channels has to be an integer multiple of the group size") - if out_channels % self.groups != 0: - raise ValueError("Error, the number of output channels has to be an integer multiple of the group size") - self.groupsize = in_channels // self.groups - scale = math.sqrt(1.0 / self.groupsize) - self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.zeros(out_channels)) - else: - self.bias = None + def get_psi(self): + psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce() + return psi def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor: # pre-multiply x with the quadrature weights x = self.quad_weights * x - psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce() + psi = self.get_psi() if x.is_cuda and use_triton_kernel: x = _disco_s2_contraction_triton(x, psi, self.nlon_out) @@ -271,7 +348,7 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens return out -class DiscreteContinuousConvTransposeS2(nn.Module): +class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): """ Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1]. 
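+
+    Compared to DiscreteContinuousConvS2, the input and output grids swap roles when the filter
+    tensor is precomputed (see the constructor below), which makes this layer the natural choice
+    for upsampling from a coarser (nlat_in, nlon_in) grid to a finer (nlat_out, nlon_out) grid.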
@@ -291,23 +368,14 @@ def __init__(
         bias: Optional[bool] = True,
         theta_cutoff: Optional[float] = None,
     ):
-        super().__init__()
+        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)
 
         self.nlat_in, self.nlon_in = in_shape
         self.nlat_out, self.nlon_out = out_shape
 
-        if isinstance(kernel_shape, int):
-            kernel_shape = [kernel_shape]
-        if len(kernel_shape) == 1:
-            self.kernel_size = kernel_shape[0]
-        elif len(kernel_shape) == 2:
-            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
-        else:
-            raise ValueError("kernel_shape should be either one- or two-dimensional.")
-
         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
 
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -318,32 +386,14 @@ def __init__(
         self.register_buffer("quad_weights", quad_weights, persistent=False)
 
         # switch in_shape and out_shape since we want transpose conv
-        idx, vals = _precompute_convolution_tensor(
-            out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
-        )
-        # psi = torch.sparse_coo_tensor(
-        #     idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
-        # ).coalesce()
+        idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff)
+
         self.register_buffer("psi_idx", idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)
-        # self.register_buffer("psi", psi, persistent=False)
-
-        # groups
-        self.groups = groups
 
-        # weight tensor
-        if in_channels % self.groups != 0:
-            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
-        if out_channels % self.groups != 0:
-            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
-        self.groupsize = in_channels // self.groups
-        scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
-
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(out_channels))
-        else:
-            self.bias = None
+    def get_psi(self):
+        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
+        return psi
 
     def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
         # extract shape
@@ -357,7 +407,7 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens
         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x
 
-        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
+        psi = self.get_psi()
 
         if x.is_cuda and use_triton_kernel:
             out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
@@ -368,3 +418,4 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens
         out = out + self.bias.reshape(1, -1, 1, 1)
 
         return out
+
diff --git a/torch_harmonics/quadrature.py b/torch_harmonics/quadrature.py
index 3222315..b2abfe5 100644
--- a/torch_harmonics/quadrature.py
+++ b/torch_harmonics/quadrature.py
@@ -31,26 +31,53 @@
 import numpy as np
 
 
+def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
+
+    if (grid != "equidistant") and periodic:
+        raise ValueError("Periodic grid is only supported on equidistant grids.")
+
+ 
# compute coordinates + if grid == "equidistant": + xlg, wlg = trapezoidal_weights(n, a=a, b=b, periodic=periodic) + elif grid == "legendre-gauss": + xlg, wlg = legendre_gauss_weights(n, a=a, b=b) + elif grid == "lobatto": + xlg, wlg = lobatto_weights(n, a=a, b=b) + elif grid == "equiangular": + xlg, wlg = clenshaw_curtiss_weights(n, a=a, b=b) + else: + raise ValueError(f"Unknown grid type {grid}") + + return xlg, wlg + def _precompute_latitudes(nlat, grid="equiangular"): r""" Convenience routine to precompute latitudes """ # compute coordinates - if grid == "legendre-gauss": - xlg, wlg = legendre_gauss_weights(nlat) - elif grid == "lobatto": - xlg, wlg = lobatto_weights(nlat) - elif grid == "equiangular": - xlg, wlg = clenshaw_curtiss_weights(nlat) - else: - raise ValueError("Unknown grid") + xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False) lats = np.flip(np.arccos(xlg)).copy() wlg = np.flip(wlg).copy() return lats, wlg +def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False): + r""" + Helper routine which returns equidistant nodes with trapezoidal weights + on the interval [a, b] + """ + + xlg = np.linspace(a, b, n) + wlg = (b - a) / (n - 1) * np.ones(n) + + if not periodic: + wlg[0] *= 0.5 + wlg[-1] *= 0.5 + + return xlg, wlg + def legendre_gauss_weights(n, a=-1.0, b=1.0): r""" Helper routine which returns the Legendre-Gauss nodes and weights