Skip to content

Commit

Permalink
enhance: TV and TGV now adapted for arbitrary batch dimensions - adde…
Browse files Browse the repository at this point in the history
…d 1D case for temporal regularization.
  • Loading branch information
mcencini committed Apr 4, 2024
1 parent 864a429 commit eeec475
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 121 deletions.
265 changes: 178 additions & 87 deletions src/deepmr/prox/tgv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ class TGVDenoiser(nn.Module):
Attributes
----------
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
Number of spatial dimensions, can be ``1``, ``2`` or ``3``.
ths : float, optional
Denoise threshold. The default is ``0.1``.
axis : int, optional
Axis over which to perform finite difference. Used only if ``ndim == 1``,
ignored otherwise. The default is ``0``.
trainable : bool, optional
If ``True``, threshold value is trainable, otherwise it is not.
The default is ``False``.
Expand Down Expand Up @@ -70,6 +73,7 @@ def __init__(
self,
ndim,
ths=0.1,
axis=0,
trainable=False,
device=None,
verbose=False,
Expand All @@ -89,6 +93,7 @@ def __init__(

self.denoiser = _TGVDenoiser(
ndim=ndim,
axis=axis,
device=device,
verbose=verbose,
n_it_max=niter,
Expand Down Expand Up @@ -151,6 +156,7 @@ def tgv_denoise(
input,
ndim,
ths=0.1,
axis=0,
device=None,
verbose=False,
niter=100,
Expand Down Expand Up @@ -185,13 +191,12 @@ def tgv_denoise(
input : np.ndarray | torch.Tensor
Input image of shape (..., n_ndim, ..., n_0).
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
Number of spatial dimensions, can be ``1``, ``2`` or ``3``.
ths : float, optional
Denoise threshold. Default is ``0.1``.
ndim : int
Number of spatial dimensions, can be either ``2`` or ``3``.
ths : float, optional
Denoise threshold. The default is ``0.1``.
axis : int, optional
Axis over which to perform finite difference. Used only if ``ndim == 1``,
ignored otherwise. The default is ``0``.
trainable : bool, optional
If ``True``, threshold value is trainable, otherwise it is not.
The default is ``False``.
Expand Down Expand Up @@ -225,8 +230,8 @@ def tgv_denoise(
isnumpy = False

# initialize denoiser
TV = TGVDenoiser(ndim, ths, False, device, verbose, niter, crit, x2, u2)
output = TV(input)
TGV = TGVDenoiser(ndim, ths, axis, False, device, verbose, niter, crit, x2, u2)
output = TGV(input)

# cast back to numpy if requried
if isnumpy:
Expand All @@ -240,7 +245,8 @@ class _TGVDenoiser(nn.Module):
def __init__(
self,
ndim,
device,
axis=0,
device=None,
verbose=False,
n_it_max=1000,
crit=1e-5,
Expand All @@ -251,8 +257,14 @@ def __init__(
super().__init__()
self.device = device
self.ndim = ndim

if ndim == 2:
self.axis = axis

if ndim == 1:
self.nabla = self.nabla1
self.nabla_adjoint = self.nabla1_adjoint
self.epsilon = self.epsilon1
self.epsilon_adjoint = self.epsilon1_adjoint
elif ndim == 2:
self.nabla = self.nabla2
self.nabla_adjoint = self.nabla2_adjoint
self.epsilon = self.epsilon2
Expand Down Expand Up @@ -392,47 +404,80 @@ def forward(self, y, ths=None):

return self.x2

def nabla1(self, x):
r"""
Applies the finite differences operator associated with tensors of the same shape as x.
"""
# move selected axis upfront
x = x.swapaxes(self.axis, -1)

# perform finite difference
u = torch.zeros(list(x.shape) + [1], device=x.device, dtype=x.dtype)
u[..., :-1, 0] = u[..., :-1, 0] - x[..., :-1]
u[..., :-1, 0] = u[..., :-1, 0] + x[..., 1:]

# place axis back into original position
x = x.swapaxes(self.axis, -1)
u = u[..., 0].swapaxes(self.axis, -1)[..., None]

return u

def nabla1_adjoint(self, x):
r"""
Applies the finite differences operator associated with tensors of the same shape as x.
"""
# move selected axis upfront
x = x[..., 0].swapaxes(self.axis, -1)[..., None]

# perform finite difference
u = torch.zeros(x.shape[:-1], device=x.device, dtype=x.dtype
) # note that we just reversed left and right sides of each line to obtain the transposed operator u[..., :-1, 0] = u[..., :-1, 0] - x[..., :-1]
u[..., :-1] = u[..., :-1] - x[..., :-1, 0]
u[..., 1:] = u[..., 1:] + x[..., :-1, 0]

# place axis back into original position
x = x[..., 0].swapaxes(self.axis, -1)[..., None]
u = u.swapaxes(self.axis, -1)

return u

@staticmethod
def nabla2(x):
r"""
Applies the finite differences operator associated with tensors of the same shape as x.
"""
b, c, h, w = x.shape
u = torch.zeros((b, c, h, w, 2), device=x.device).type(x.dtype)
u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1]
u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] + x[:, :, 1:]
u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[..., :-1]
u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] + x[..., 1:]
u = torch.zeros(list(x.shape) + [2], device=x.device, dtype=x.dtype)
u[..., :-1, :, 0] = u[..., :-1, :, 0] - x[..., :, :-1]
u[..., :-1, :, 0] = u[..., :-1, :, 0] + x[..., :, 1:]
u[..., :, :-1, 1] = u[..., :, :-1, 1] - x[..., :-1, :]
u[..., :, :-1, 1] = u[..., :, :-1, 1] + x[..., 1:, :]
return u

@staticmethod
def nabla2_adjoint(x):
r"""
Applies the adjoint of the finite difference operator.
"""
b, c, h, w = x.shape[:-1]
u = torch.zeros((b, c, h, w), device=x.device).type(
x.dtype
u = torch.zeros(x.shape[:-1], device=x.device, dtype=x.dtype
) # note that we just reversed left and right sides of each line to obtain the transposed operator
u[:, :, :-1] = u[:, :, :-1] - x[:, :, :-1, :, 0]
u[:, :, 1:] = u[:, :, 1:] + x[:, :, :-1, :, 0]
u[..., :-1] = u[..., :-1] - x[..., :-1, 1]
u[..., 1:] = u[..., 1:] + x[..., :-1, 1]
u[..., :, :-1] = u[..., :, :-1] - x[..., :-1, :, 0]
u[..., :, 1:] = u[..., :, 1:] + x[..., :-1, :, 0]
u[..., :-1, :] = u[..., :-1, :] - x[..., :, :-1, 1]
u[..., 1:, :] = u[..., 1:, :] + x[..., :, :-1, 1]
return u

@staticmethod
def nabla3(x):
r"""
Applies the finite differences operator associated with tensors of the same shape as x.
"""
b, c, d, h, w = x.shape
u = torch.zeros((b, c, d, h, w, 3), device=x.device).type(x.dtype)
u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] - x[:, :, :-1]
u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] + x[:, :, 1:]
u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] - x[:, :, :, :-1]
u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] + x[:, :, :, 1:]
u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1]
u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] + x[:, :, :, :, 1:]
u = torch.zeros(list(x.shape) + [3], device=x.device, dtype=x.dtype)
u[..., :-1, :, :, 0] = u[..., :-1, :, :, 0] - x[..., :, :, :-1]
u[..., :-1, :, :, 0] = u[..., :-1, :, :, 0] + x[..., :, :, 1:]
u[..., :, :-1, :, 1] = u[..., :, :-1, :, 1] - x[..., :, :-1, :]
u[..., :, :-1, :, 1] = u[..., :, :-1, :, 1] + x[..., :, 1:, :]
u[..., :, :, :-1, 2] = u[..., :, :, :-1, 2] - x[..., :-1, :, :]
u[..., :, :, :-1, 2] = u[..., :, :, :-1, 2] + x[..., 1:, :, :]

return u

Expand All @@ -441,91 +486,137 @@ def nabla3_adjoint(x):
r"""
Applies the adjoint of the finite difference operator.
"""
b, c, d, h, w = x.shape
u = torch.zeros((b, c, d, h, w), device=x.device).type(x.dtype)
u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1]
u[:, :, 1:, :, 0] = u[:, :, 1:, :, 0] + x[:, :, :-1]
u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[:, :, :, :-1]
u[:, :, :, 1:, 1] = u[:, :, :, 1:, 1] + x[:, :, :, :-1]
u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1]
u[:, :, :, :, 1:, 2] = u[:, :, :, :, 1:, 2] + x[:, :, :, :, :-1]
u = torch.zeros(x.shape[:-1], device=x.device, dtype=x.dtype)
u[..., :, :, :-1] = u[..., :, :, :-1] - x[..., :-1, :, :, 0]
u[..., :, :, 1:] = u[..., :, :, 1:] + x[..., :-1, :, :, 0]
u[..., :, :-1, :] = u[..., :, :-1, :] - x[..., :, :-1, :, 1]
u[..., :, 1:, :] = u[..., :, 1:, :] + x[..., :, :-1, :, 1]
u[..., :-1, :, :] = u[..., :-1, :, :] - x[..., :, :, :-1, 2]
u[..., 1:, :, :] = u[..., 1:, :, :] + x[..., :, :, :-1, 2]

return u

def epsilon1(self, I):
r"""
Applies the jacobian of a vector field.
"""
# move selected axis upfront
I = I[..., 0].swapaxes(self.axis, -1)[..., None]

# perform finite difference
G = torch.zeros(list(I.shape[:-1]) + [1], device=I.device, dtype=I.dtype)
G[..., 1:, :, 0] = G[..., 1:, :, 0] - I[..., :-1, :, 0] # xdx
G[..., 0] = G[..., 0] + I[..., 0]

# place axis back into original position
I = I[..., 0].swapaxes(self.axis, -1)[..., None]
G = G[..., 0].swapaxes(self.axis, -1)[..., None]

return G

def epsilon1_adjoint(self, G):
r"""
Applies the adjoint of the jacobian of a vector field.
"""
# move selected axis upfront
G = G[..., 0].swapaxes(self.axis, -1)[..., None]

# perform finite difference
I = torch.zeros(list(G.shape[:-1]) + [1], device=G.device, dtype=G.dtype)
I[..., :-1, :, 0] = I[..., :-1, :, 0] - G[..., 1:, :, 0] # xdx
I[..., 0] = I[..., 0] + G[..., 0]
I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, 1] # xdy

# place axis back into original position
I = I[..., 0].swapaxes(self.axis, -1)[..., None]
G = G[..., 0].swapaxes(self.axis, -1)[..., None]

return I

@staticmethod
def epsilon2(I): # Simplified
def epsilon2(I):
r"""
Applies the jacobian of a vector field.
"""
b, c, h, w, _ = I.shape
G = torch.zeros((b, c, h, w, 4), device=I.device).type(I.dtype)
G[:, :, 1:, :, 0] = G[:, :, 1:, :, 0] - I[:, :, :-1, :, 0] # xdy
G = torch.zeros(list(I.shape[:-1]) + [4], device=I.device, dtype=I.dtype)
G[..., 1:, :, 0] = G[..., 1:, :, 0] - I[..., :-1, :, 0] # xdx
G[..., 0] = G[..., 0] + I[..., 0]
G[..., 1:, 1] = G[..., 1:, 1] - I[..., :-1, 0] # xdx
G[..., 1:, 1] = G[..., 1:, 1] - I[..., :-1, 0] # xdy
G[..., 1:, 1] = G[..., 1:, 1] + I[..., 1:, 0]
G[..., 1:, 2] = G[..., 1:, 2] - I[..., :-1, 1] # xdx
G[..., 1:, 2] = G[..., 1:, 2] - I[..., :-1, 1] # ydx
G[..., 2] = G[..., 2] + I[..., 1]
G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] - I[:, :, :-1, :, 1] # xdy
G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] + I[:, :, 1:, :, 1]
G[..., :-1, :, 3] = G[..., :-1, :, 3] - I[..., :-1, :, 1] # ydy
G[..., :-1, :, 3] = G[..., :-1, :, 3] + I[..., 1:, :, 1]

return G

@staticmethod
def epsilon2_adjoint(G):
r"""
Applies the adjoint of the jacobian of a vector field.
"""
b, c, h, w, _ = G.shape
I = torch.zeros((b, c, h, w, 2), device=G.device).type(G.dtype)
I[:, :, :-1, :, 0] = I[:, :, :-1, :, 0] - G[:, :, 1:, :, 0]
I = torch.zeros(list(G.shape[:-1]) + [2], device=G.device, dtype=G.dtype)
I[..., :-1, :, 0] = I[..., :-1, :, 0] - G[..., 1:, :, 0] # xdx
I[..., 0] = I[..., 0] + G[..., 0]
I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, 1]
I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, 1] # xdy
I[..., 1:, 0] = I[..., 1:, 0] + G[..., 1:, 1]
I[..., :-1, 1] = I[..., :-1, 1] - G[..., 1:, 2]
I[..., :-1, 1] = I[..., :-1, 1] - G[..., 1:, 2] # ydx
I[..., 1] = I[..., 1] + G[..., 2]
I[:, :, :-1, :, 1] = I[:, :, :-1, :, 1] - G[:, :, :-1, :, 3]
I[:, :, 1:, :, 1] = I[:, :, 1:, :, 1] + G[:, :, :-1, :, 3]
I[..., :-1, :, 1] = I[..., :-1, :, 1] - G[..., :-1, :, 3] # ydy
I[..., 1:, :, 1] = I[..., 1:, :, 1] + G[..., :-1, :, 3]

return I

@staticmethod
def epsilon3(I): # Adapted for 3D matrices
def epsilon3(I):
r"""
Applies the jacobian of a vector field.
"""
b, c, d, h, w = I.shape
G = torch.zeros((b, c, d, h, w, 6), device=I.device).type(I.dtype)
G[:, :, :, 1:, :, 0] = G[:, :, :, 1:, :, 0] - I[:, :, :, :-1, :, 0] # xdy
G = torch.zeros(list(I.shape[:-1]) + [9], device=I.device, dtype=I.dtype)
G[..., 1:, :, :, 0] = G[..., 1:, :, :, 0] - I[..., :-1, :, :, 0] # xdx
G[..., 0] = G[..., 0] + I[..., 0]
G[..., 1:, :, 1] = G[..., 1:, :, 1] - I[..., :, :-1, 0] # xdx
G[..., 1:, :, 1] = G[..., 1:, :, 1] + I[..., :, 1:, 0]
G[..., 1:, :, 2] = G[..., 1:, :, 2] - I[..., :, :-1, 1] # xdz
G[..., 2] = G[..., 2] + I[..., :, 1, 0]
G[:, :, :, :-1, :, 3] = G[:, :, :, :-1, :, 3] - I[:, :, :, :-1, :, 1] # xdy
G[:, :, :, :-1, :, 3] = G[:, :, :, :-1, :, 3] + I[:, :, :, 1:, :, 1]
G[..., 3] = G[..., 3] + I[..., 0]
G[..., 1:, :, 4] = G[..., 1:, :, 4] - I[..., 1:, :, :-1, 2] # xdz
G[..., 4] = G[..., 4] + I[..., 1, :, :, 0]
G[:, :, :, :, :-1, 5] = G[:, :, :, :, :-1, 5] - I[:, :, :, :, :-1, 2] # xdy
G[:, :, :, 1:, :, 5] = G[:, :, :, 1:, :, 5] + I[:, :, :, :, :-1, 2]
G[..., 1:, :, 1, 1] = G[..., 1:, :, 1, 1] - I[..., :-1, :, :, 1] # xdy
G[..., 1:, :, 1, 1] = G[..., 1:, :, 1, 1] + I[..., 1:, :, :, 1]
G[..., 1:, 1, :, 2] = G[..., 1:, 1, :, 2] - I[..., :-1, :, :, 2] # xdz
G[..., 2] = G[..., 2] + I[..., 1, :, :, 2]
G[..., :-1, :, :, 3] = G[..., :-1, :, :, 3] - I[..., :-1, :, :, 3] # ydx
G[..., :-1, :, :, 3] = G[..., :-1, :, :, 3] + I[..., 1:, :, :, 3]
G[..., 1:, :, 1, 4] = G[..., 1:, :, 1, 4] - I[..., :-1, :, :, 4] # ydy
G[..., 1:, :, 1, 4] = G[..., 1:, :, 1, 4] + I[..., 1:, :, :, 4]
G[..., 1:, 1, :, 5] = G[..., 1:, 1, :, 5] - I[..., :-1, :, :, 5] # ydz
G[..., 3] = G[..., 3] + I[..., 1, :, :, 5]
G[..., :-1, 1, :, 6] = G[..., :-1, 1, :, 6] - I[..., :-1, :, :, 6] # zdx
G[..., 4] = G[..., 4] + I[..., 1, :, :, 6]
G[..., 1:, :, :, 7] = G[..., 1:, :, :, 7] - I[..., :-1, :, :, 7] # zdy
G[..., 5] = G[..., 5] + I[..., 1:, :, :, 7]
G[..., :-1, :, 1, 8] = G[..., :-1, :, 1, 8] - I[..., :-1, :, 1, 8] # zdz
G[..., 6] = G[..., 6] + I[..., 1:, :, 1, 8]

return G

@staticmethod
def epsilon3_adjoint(G): # Adapted for 3D matrices
def epsilon3_adjoint(G):
r"""
Applies the adjoint of the jacobian of a vector field.
"""
b, c, d, h, w, _ = G.shape
I = torch.zeros((b, c, d, h, w, 3), device=G.device).type(G.dtype)
I[:, :, :, :-1, :, 0] = I[:, :, :, :-1, :, 0] - G[:, :, :, 1:, :, 0]
I = torch.zeros(list(G.shape[:-1]) + [3], device=G.device, dtype=G.dtype)
I[..., :-1, :, :, 0] = I[..., :-1, :, :, 0] - G[..., 1:, :, :, 0] # xdx
I[..., 0] = I[..., 0] + G[..., 0]
I[..., :-1, :, 0] = I[..., :-1, :, 0] - G[..., 1:, :, 1]
I[..., :-1, :, 0] = I[..., :-1, :, 0] - G[..., 1:, :, 1] # xdy
I[..., 1:, :, 0] = I[..., 1:, :, 0] + G[..., 1:, :, 1]
I[..., :-1, :, 1] = I[..., :-1, :, 1] - G[..., 1:, :, 2]
I[..., 0] = I[..., 0] + G[..., 2]
I[:, :, :, :-1, :, 1] = I[:, :, :, :-1, :, 1] - G[:, :, :, :-1, :, 3]
I[:, :, :, 1:, :, 1] = I[:, :, :, 1:, :, 1] + G[:, :, :, :-1, :, 3]
I[..., 1] = I[..., 1] + G[..., 3]
I[..., :-1, :, 2] = I[..., :-1, :, 2] - G[..., 1:, :, 4]
I[..., 0] = I[..., 0] + G[..., 4]
I[:, :, :, :, :-1, 2] = I[:, :, :, :, :-1, 2] - G[:, :, :, :, :-1, 5]
I[:, :, :, 1:, :, 2] = I[:, :, :, 1:, :, 2] + G[:, :, :, :, :-1, 5]
return I
I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, :, 2] # xdz
I[..., 2] = I[..., 2] + G[..., 2]
I[..., :-1, :, :, 1] = I[..., :-1, :, :, 1] - G[..., 1:, :, :, 3] # ydx
I[..., 1:, :, :, 1] = I[..., 1:, :, :, 1] + G[..., :-1, :, :, 3]
I[..., :-1, :, 1] = I[..., :-1, :, 1] - G[..., 1:, :, 4] # ydy
I[..., 1:, :, 1] = I[..., 1:, :, 1] + G[..., :-1, :, 4]
I[..., :-1, 1] = I[..., :-1, 1] - G[..., 1:, :, 5] # ydz
I[..., 2] = I[..., 2] + G[..., :-1, :, 5]
I[..., :, :-1, :, 2] = I[..., :, :-1, :, 2] - G[..., :, :-1, :, 6] # zdx
I[..., :, 1:, :, 2] = I[..., :, 1:, :, 2] + G[..., :, :-1, :, 6]
I[..., :, :-1, 2] = I[..., :, :-1, 2] - G[..., :, 1:, 7] # zdy
I[..., :, 1:, 2] = I[..., :, 1:, 2] + G[..., :, 1:, 7]
I[..., :-1, :-1, 2] = I[..., :-1, :-1, 2] - G[..., 1:, 1:, 8] # zdz
I[..., 1:, 1:, 2] = I[..., 1:, 1:, 2] + G[..., 1:, 1:, 8]

return I
Loading

0 comments on commit eeec475

Please sign in to comment.