Skip to content

Commit

Permalink
wip: nonlinear inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mcencini committed Apr 4, 2024
1 parent e7e14ad commit 864a429
Show file tree
Hide file tree
Showing 11 changed files with 457 additions and 127 deletions.
10 changes: 4 additions & 6 deletions src/deepmr/nlops/enlive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
# import torch.nn as nn

# class Base(nn.Module):

# def __init__(self, linop, linop_grad):
# r"""
# Initiate the linear operator.
# """
# super().__init__()

# # compute terms
# self.F = linop
# self.DF = linop_grad

# def forward(self, x, dx):
# return self.F(x) + dx * self.DF(x)

# class NLINVAdjoint(nn.Module):

# def __init__(self):
# r"""
# Initiate the linear operator.
Expand All @@ -31,5 +31,3 @@

# def forward(self):
# pass


16 changes: 15 additions & 1 deletion src/deepmr/prox/llr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class LLRDenoiser(nn.Module):
device : str, optional
Device on which the wavelet transform is computed.
The default is ``None`` (infer from input).
offset : torch.Tensor, optional
Offset applied to regularization input, i.e. ``output = W(input + offset)``
Must be either a scalar or its shape must support broadcast with ``input``.
"""

Expand All @@ -53,6 +56,7 @@ def __init__(
rand_shift=True,
axis=None,
device=None,
offset=None,
):
super().__init__()

Expand All @@ -73,15 +77,25 @@ def __init__(
else:
self.axis = axis
self.device = device

if offset is not None:
self.offset = torch.as_tensor(offset)
else:
self.offset = None

def forward(self, x):
def forward(self, input):
x = input
# default device
idevice = x.device
if self.device is None:
device = idevice
else:
device = self.device
x = x.to(device)

# apply offset
if self.offset is not None:
x = x.to(device) + self.offset.to(device)

# circshift randomly
if self.rand_shift is True:
Expand Down
17 changes: 15 additions & 2 deletions src/deepmr/prox/tgv.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class TGVDenoiser(nn.Module):
Dual variable for warm restart. The default is ``None``.
r2 : torch.Tensor, optional
Auxiliary variable for warm restart. The default is ``None``.
offset : torch.Tensor, optional
Offset applied to regularization input, i.e. ``output = W(input + offset)``
Must be either a scalar or its shape must support broadcast with ``input``.
Notes
-----
Expand All @@ -75,6 +78,7 @@ def __init__(
x2=None,
u2=None,
r2=None,
offset=None,
):
super().__init__()

Expand All @@ -94,6 +98,11 @@ def __init__(
r2=r2,
)
self.denoiser.device = device

if offset is not None:
self.offset = torch.as_tensor(offset)
else:
self.offset = None

def forward(self, input):
# get complex
Expand All @@ -112,15 +121,19 @@ def forward(self, input):
# get input shape
ndim = self.denoiser.ndim
ishape = input.shape

# apply offset
if self.offset is not None:
input = input.to(device) + self.offset.to(device)

# reshape for computation
input = input.reshape(-1, *ishape[-ndim:])
if iscomplex:
input = torch.stack((input.real, input.imag), axis=1)
input = input.reshape(-1, *ishape[-ndim:])

# apply denoising
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
output = self.denoiser(input.to(device), self.ths).to(
idevice
) # perform the denoising on the real-valued tensor

Expand Down
15 changes: 14 additions & 1 deletion src/deepmr/prox/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class TVDenoiser(nn.Module):
Primary variable for warm restart. The default is ``None``.
u2 : torch.Tensor, optional
Dual variable for warm restart. The default is ``None``.
offset : torch.Tensor, optional
Offset applied to regularization input, i.e. ``output = W(input + offset)``
Must be either a scalar or its shape must support broadcast with ``input``.
Notes
-----
Expand All @@ -72,6 +75,7 @@ def __init__(
crit=1e-5,
x2=None,
u2=None,
offset=None,
):
super().__init__()

Expand All @@ -90,6 +94,11 @@ def __init__(
u2=u2,
)
self.denoiser.device = device

if offset is not None:
self.offset = torch.as_tensor(offset)
else:
self.offset = None

def forward(self, input):
# get complex
Expand All @@ -108,6 +117,10 @@ def forward(self, input):
# get input shape
ndim = self.denoiser.ndim
ishape = input.shape

# apply offset
if self.offset is not None:
input = input.to(device) + self.offset.to(device)

# reshape for computation
input = input.reshape(-1, *ishape[-ndim:])
Expand All @@ -116,7 +129,7 @@ def forward(self, input):
input = input.reshape(-1, *ishape[-ndim:])

# apply denoising
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
output = self.denoiser(input.to(device), self.ths).to(
idevice
) # perform the denoising on the real-valued tensor

Expand Down
30 changes: 28 additions & 2 deletions src/deepmr/prox/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class WaveletDenoiser(nn.Module):
The default is ``"soft"``.
level: int, optional
Level of the wavelet transform. The default is ``None``.
offset : torch.Tensor, optional
Offset applied to regularization input, i.e. ``output = W(input + offset)``
Must be either a scalar or its shape must support broadcast with ``input``.
"""

Expand All @@ -80,6 +83,7 @@ def __init__(
device=None,
non_linearity="soft",
level=None,
offset=None,
*args,
**kwargs
):
Expand All @@ -100,6 +104,11 @@ def __init__(
**kwargs
)
self.denoiser.device = device

if offset is not None:
self.offset = torch.as_tensor(offset)
else:
self.offset = None

def forward(self, input):
# get complex
Expand All @@ -118,6 +127,10 @@ def forward(self, input):
# get input shape
ndim = self.denoiser.dimension
ishape = input.shape

# apply offset
if self.offset is not None:
input = input.to(device) + self.offset.to(device)

# reshape for computation
input = input.reshape(-1, *ishape[-ndim:])
Expand All @@ -126,7 +139,7 @@ def forward(self, input):
input = input.reshape(-1, *ishape[-ndim:])

# apply denoising
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
output = self.denoiser(input.to(device), self.ths).to(
idevice
) # perform the denoising on the real-valued tensor

Expand Down Expand Up @@ -249,6 +262,9 @@ class WaveletDictDenoiser(nn.Module):
max_iter : int, optional
Number of iterations of the optimization algorithm.
The default is ``10``.
offset : torch.Tensor, optional
Offset applied to regularization input, i.e. ``output = W(input + offset)``
Must be either a scalar or its shape must support broadcast with ``input``.
"""

Expand All @@ -262,6 +278,7 @@ def __init__(
non_linearity="soft",
level=None,
max_iter=10,
offset=None,
*args,
**kwargs
):
Expand All @@ -284,6 +301,11 @@ def __init__(
)

self.denoiser.device = device

if offset is not None:
self.offset = torch.as_tensor(offset)
else:
self.offset = None

def forward(self, input):
# get complex
Expand All @@ -302,6 +324,10 @@ def forward(self, input):
# get input shape
ndim = self.denoiser.dimension
ishape = input.shape

# apply offset
if self.offset is not None:
input = input.to(device) + self.offset.to(device)

# reshape for computation
input = input.reshape(-1, *ishape[-ndim:])
Expand All @@ -310,7 +336,7 @@ def forward(self, input):
input = input.reshape(-1, *ishape[-ndim:])

# apply denoising
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
output = self.denoiser(input.to(device), self.ths).to(
idevice
) # perform the denoising on the real-valued tensor

Expand Down
6 changes: 5 additions & 1 deletion src/deepmr/recon/alg/classic_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
import numpy as np
import torch


from ... import linops as _linops
from ... import optim as _optim
from ... import prox as _prox


from .. import calib as _calib
from ... import linops as _linops


from . import linop as _linop

Expand Down
Loading

0 comments on commit 864a429

Please sign in to comment.