Skip to content

Commit

Permalink
Merged conflict: Added modifs from UPGD on recon
Browse files Browse the repository at this point in the history
  • Loading branch information
jabascal committed Jul 13, 2023
1 parent 0983ad8 commit f75307f
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions spyrit/core/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,103 @@ def set_noise_level(self, noise_level):
:attr:`output`: noise level tensor with shape :math:`(1)`
"""
self.noise_level = torch.FloatTensor([noise_level/255.])


#%%===========================================================================================
class PositiveParameters(nn.Module):
# ===========================================================================================
def __init__(self, size, val_min=1e-6):
super(PositiveParameters, self).__init__()
self.val_min = torch.tensor(val_min)
self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True)

def forward(self):
return torch.abs(self.params)

#%%===========================================================================================
class PositiveMonoIncreaseParameters(PositiveParameters):
# ===========================================================================================
def __init__(self, size, val_min=0.000001):
super().__init__(size, val_min)

def forward(self):
# cumsum in opposite order
return super().forward().cumsum(dim=0).flip(dims=[0])

#%%===========================================================================================
class UPGD(PinvNet):
# ===========================================================================================
def __init__(self,
noise,
prep,
denoi=nn.Identity(),
num_iter = 6,
lamb = 1e-5,
lamb_min = 1e-6,
split=False):
super(UPGD, self).__init__(noise, prep, denoi)
self.num_iter = num_iter
self.lamb = lamb
self.lamb_min = lamb_min
# Set a trainable tensor for the regularization parameter with dimension num_iter
# and constrained to be positive with clamp(min=0.0, max=None)
self.lambs = PositiveMonoIncreaseParameters(num_iter, lamb_min) # shape lambs = [num_iter,1]
#self.noise = noise
self.split = split

def reconstruct(self, x):
r""" Reconstruction step of a reconstruction network
Same as :meth:`reconstruct` reconstruct except that:
1. The regularization parameter is trainable
Args:
:attr:`x`: raw measurement vectors
Shape:
:attr:`x`: :math:`(BC,2M)`
:attr:`output`: :math:`(BC,1,H,W)`
"""

# Measurement operator
#if self.split:
# meas = super().Acq.meas_op
#else:
#meas = self.Acq.meas_op
meas = self.acqu.meas_op

# x of shape [b*c, 2M]
bc, _ = x.shape

# First estimate: Pseudo inverse
# Preprocessing in the measurement domain
x = self.prep(x) # [5, 1024]

# Save measurements
m = x.clone() # [5, 1024]

# measurements to image domain processing
x = self.pinv(x, self.acqu.meas_op) # [5, 4096] # shape x = [b*c,N]
#x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # shape x = [b*c,1,h,w]

# Unroll network
# Ensure step size is positive and monotonically decreasing and larger than self.lamb!
lambs = self.lambs()
for n in range(self.num_iter):
# Projection onto the measurement space
proj = self.acqu.meas_op.forward_H(x) # [5, 1024]

# Residual
res = proj - m # [5, 1024]

# Gradient step
x = x + lambs[n]*self.acqu.meas_op.H_adjoint(res) # [5, 4096]

# Denoising step
x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # [5, 1, 64, 64]
x = self.denoi(x)
x = x.view(bc, self.acqu.meas_op.N) # [5, 4096]
return x

0 comments on commit f75307f

Please sign in to comment.