Skip to content

Commit

Permalink
Merge pull request #201 from openspyrit/clean_lpgd
Browse files Browse the repository at this point in the history
create tuto lpgd
  • Loading branch information
romainphan authored Jun 11, 2024
2 parents eda7674 + 143d590 commit 57f5613
Show file tree
Hide file tree
Showing 3 changed files with 493 additions and 11 deletions.
Binary file added docs/source/fig/lpgd.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
296 changes: 285 additions & 11 deletions spyrit/core/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,240 @@ def forward(self, x):

return x

def acquire(self, x):
r"""Simulate data acquisition
Args:
:attr:`x`: ground-truth images
Shape:
:attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`
:attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)`
"""

b, c, h, w = x.shape
x = 0.5 * (x + 1)
x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1)
return x

def set_noise_level(self, noise_level):
r"""Reset noise level value
Args:
:attr:`noise_level`: noise level value in the range [0, 255]
Shape:
:attr:`noise_level`: float value noise level :math:`(1)`
:attr:`output`: noise level tensor with shape :math:`(1)`
"""
self.noise_level = torch.FloatTensor([noise_level / 255.0])


# %%===========================================================================================
class PositiveParameters(nn.Module):
# ===========================================================================================
def __init__(self, params, requires_grad=True):
super(PositiveParameters, self).__init__()
self.params = torch.tensor(params, requires_grad=requires_grad)

def forward(self):
return torch.abs(self.params)


# =============================================================================
class LearnedPGD(nn.Module):
r"""Learned Proximal Gradient Descent reconstruction network.
Iterative algorithm that alternates between a gradient step and a proximal step,
where the proximal operator is learned denoiser. The update rule is given by:
:math:`x_{k+1} = prox(\hat{x_k} - step * H^T (Hx_k - y))=
denoi(\hat{x_k} - step * H^T (Hx_k - y))`
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
:attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
:attr:`denoi` (optional): Image denoising operator
(see :class:`~spyrit.core.nnet`).
Default :class:`~spyrit.core.nnet.Identity`
:attr:`iter_stop` (int): Number of iterations of the LPGD algorithm
(commonly 3 to 10, trade-off between accuracy and speed).
Default 3 (for speed and with higher accuracy than post-processing denoising)
:attr:`step` (float): Step size of the LPGD algorithm. Default is None,
and it is estimated as the inverse of the Lipschitz constant of the gradient of the
data fidelity term.
- If :math:`meas_op.N` is available, the step size is estimated as
:math:`step=1/L=1/\text{meas_op.N}`, true for Hadamard operators.
- If not, the step size is estimated from by computing
the Lipschitz constant as the largest singular value of the
Hessians, :math:`L=\lambda_{\max}(H^TH)`. If this fails,
the step size is set to 1e-4.
:attr:`step_estimation` (bool): Default False. See :attr:`step` for details.
:attr:`step_grad` (bool): Default False. If True, the step size is learned
as a parameter of the network. Not tested yet.
:attr:`wls` (bool): Default False. If True, the data fidelity term is
modified to be the weighted least squares (WLS) term, which approximates
the Poisson likelihood. In this case, the data fidelity term is
:math:`\|Hx-y\|^2_{C^{-1}}`, where :math:`C` is the covariance matrix.
We assume that :math:`C` is diagonal, and the diagonal elements are
the measurement noise variances, estimated from :class:`~spyrit.core.prep.sigma`.
:attr:`gt` (torch.tensor): Ground-truth images. If available, the mean
squared error (MSE) is computed and logged. Default None.
:attr:`log_fidelity` (bool): Default False. If True, the data fidelity term
is logged for each iteration of the LPGD algorithm.
Input / Output:
:attr:`input`: Ground-truth images with shape :math:`(B,C,H,W)`
:attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
Attributes:
:attr:`Acq`: Acquisition operator initialized as :attr:`noise`
:attr:`prep`: Preprocessing operator initialized as :attr:`prep`
:attr:`pinv`: Analytical reconstruction operator initialized as
:class:`~spyrit.core.recon.PseudoInverse()`
:attr:`Denoi`: Image denoising operator initialized as :attr:`denoi`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = torch.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> recnet = PinvNet(noise, prep)
>>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
>>> z = recnet(x)
>>> print(z.shape)
>>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
torch.Size([10, 1, 64, 64])
tensor(5.8912e-06)
"""

def __init__(
self,
noise,
prep,
denoi=nn.Identity(),
iter_stop=3,
x0=0,
step=None,
step_estimation=False,
step_grad=False,
step_decay=1,
wls=False,
gt=None,
log_fidelity=False,
res_learn=False,
):
super().__init__()
# nn.module
self.acqu = noise
self.prep = prep
self.denoi = denoi

self.pinv = PseudoInverse()

# LPGD algo
self.x0 = x0
self.iter_stop = iter_stop
self.step = step
self.step_estimation = step_estimation
self.step_grad = step_grad
self.step_decay = step_decay
self.res_learn = res_learn

# Init step size (estimate)
self.set_stepsize(step)

# WLS
self.wls = wls

# Log fidelity
self.log_fidelity = log_fidelity

# Log MSE (Ground truth available)
if gt is not None:
self.x_gt = nn.Parameter(
torch.tensor(gt.reshape(gt.shape[0], -1)), requires_grad=False
)
else:
self.x_gt = None

def step_schedule(self, step):
if self.step_decay != 1:
step = [step * self.step_decay**i for i in range(self.iter_stop)]
elif self.iter_stop > 1:
step = [step for i in range(self.iter_stop)]
else:
step = [step]
return step

def set_stepsize(self, step):
if step is None:
# Stimate stepsize from Lipschitz constant
if hasattr(self.acqu.meas_op, "N"):
step = 1 / self.acqu.meas_op.N
else:
# Estimate step size as 1/sv_max(H^TH); if failed, set to 1e-4
self.step_estimation = True
step = 1e-4

step = self.step_schedule(step)
# step = nn.Parameter(torch.tensor(step), requires_grad=self.step_grad)
step = PositiveParameters(step, requires_grad=self.step_grad)
self.step = step

def forward(self, x):
r"""Full pipeline of reconstrcution network
Args:
:attr:`x`: ground-truth images
Shape:
:attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`
:attr:`output`: reconstructed images with shape :math:`(B,C,H,W)`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = torch.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> recnet = PinvNet(noise, prep)
>>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
>>> z = recnet(x)
>>> print(z.shape)
>>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
torch.Size([10, 1, 64, 64])
tensor(5.8912e-06)
"""

b, c, _, _ = x.shape

# Acquisition
x = x.view(b * c, self.acqu.meas_op.N) # shape x = [b*c,h*w] = [b*c,N]
x = self.acqu(x) # shape x = [b*c, 2*M]

# Reconstruction
x = self.reconstruct(x) # shape x = [bc, 1, h,w]
x = x.view(b, c, self.acqu.meas_op.h, self.acqu.meas_op.w)

return x

def acquire(self, x):
r"""Simulate data acquisition
Expand All @@ -1029,7 +1263,7 @@ def acquire(self, x):
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> recnet = LearnedPGD(noise, prep)
>>> recnet = PinvNet(noise, prep)
>>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
>>> z = recnet.acquire(x)
>>> print(z.shape)
Expand Down Expand Up @@ -1071,7 +1305,7 @@ def cost_fun(self, x, y):
def mse_fun(self, x, x_gt):
return torch.linalg.norm(x - x_gt)

def reconstruct(self, x, exp=False):
def reconstruct(self, x):
r"""Reconstruction step of a reconstruction network
Args:
Expand All @@ -1088,7 +1322,7 @@ def reconstruct(self, x, exp=False):
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H**2)
>>> recnet = LearnedPGD(noise, prep)
>>> recnet = PinvNet(noise, prep)
>>> x = torch.rand((B*C,2*M), dtype=torch.float)
>>> z = recnet.reconstruct(x)
>>> print(z.shape)
Expand All @@ -1098,6 +1332,14 @@ def reconstruct(self, x, exp=False):
# Measurement to image domain mapping
bc, _ = x.shape

# Compute the stepsize from the Lipschitz constant
if self.step_estimation:
self.stepsize_gd()

step = self.step
if not isinstance(step, torch.Tensor):
step = step.params

# Compute the stepsize from the Lipschitz constant
if self.step_estimation:
self.stepsize_gd()
Expand All @@ -1107,10 +1349,7 @@ def reconstruct(self, x, exp=False):
step = step.params

# Preprocessing in the measurement domain
if exp:
m, N0_est = self.prep.forward_expe(x, self.acqu.meas_op)
else:
m = self.prep(x) # shape x = [b*c, M]
m = self.prep(x) # shape x = [b*c, M]

if self.wls:
# Get variance of the measurements
Expand Down Expand Up @@ -1198,9 +1437,44 @@ def reconstruct(self, x, exp=False):
if self.res_learn:
# z=x-step*grad(L), x = P(z), x_end = z0 + P(z)
x = x + z0
return x

if exp:
x = self.prep.denormalize_expe(
x, N0_est, self.acqu.meas_op.h, self.acqu.meas_op.w
)
def reconstruct_expe(self, x):
r"""Reconstruction step of a reconstruction network
Same as :meth:`reconstruct` reconstruct except that:
1. The preprocessing step estimates the image intensity for normalization
2. The output images are "denormalized", i.e., have units of photon counts
Args:
:attr:`x`: raw measurement vectors
Shape:
:attr:`x`: :math:`(BC,2M)`
:attr:`output`: :math:`(BC,1,H,W)`
"""
# x of shape [b*c, 2M]
bc, _ = x.shape

# Preprocessing
x, N0_est = self.prep.forward_expe(x, self.acqu.meas_op) # shape x = [b*c, M]
print(N0_est)

# measurements to image domain processing
x = self.pinv(x, self.acqu.meas_op) # shape x = [b*c,N]

# Image domain denoising
x = x.view(
bc, 1, self.acqu.meas_op.h, self.acqu.meas_op.w
) # shape x = [b*c,1,h,w]
x = self.denoi(x) # shape x = [b*c,1,h,w]
print(x.max())

# Denormalization
x = self.prep.denormalize_expe(
x, N0_est, self.acqu.meas_op.h, self.acqu.meas_op.w
)
return x
Loading

0 comments on commit 57f5613

Please sign in to comment.