diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py index ce577965..ebf496d5 100644 --- a/spyrit/core/recon.py +++ b/spyrit/core/recon.py @@ -887,3 +887,103 @@ def set_noise_level(self, noise_level): :attr:`output`: noise level tensor with shape :math:`(1)` """ self.noise_level = torch.FloatTensor([noise_level/255.]) + + +#%%=========================================================================================== +class PositiveParameters(nn.Module): +# =========================================================================================== + def __init__(self, size, val_min=1e-6): + super(PositiveParameters, self).__init__() + self.val_min = torch.tensor(val_min) + self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True) + + def forward(self): + return torch.abs(self.params) + +#%%=========================================================================================== +class PositiveMonoIncreaseParameters(PositiveParameters): +# =========================================================================================== + def __init__(self, size, val_min=0.000001): + super().__init__(size, val_min) + + def forward(self): + # cumsum in opposite order + return super().forward().cumsum(dim=0).flip(dims=[0]) + +#%%=========================================================================================== +class UPGD(PinvNet): +# =========================================================================================== + def __init__(self, + noise, + prep, + denoi=nn.Identity(), + num_iter = 6, + lamb = 1e-5, + lamb_min = 1e-6, + split=False): + super(UPGD, self).__init__(noise, prep, denoi) + self.num_iter = num_iter + self.lamb = lamb + self.lamb_min = lamb_min + # Set a trainable tensor for the regularization parameter with dimension num_iter + # and constrained to be positive with clamp(min=0.0, max=None) + self.lambs = PositiveMonoIncreaseParameters(num_iter, lamb_min) # shape lambs = [num_iter,1] + #self.noise = noise + self.split = split + + def reconstruct(self, x): + r""" Reconstruction step of a reconstruction network + + Same as :meth:`reconstruct` reconstruct except that: + + 1. The regularization parameter is trainable + + Args: + :attr:`x`: raw measurement vectors + + Shape: + :attr:`x`: :math:`(BC,2M)` + + :attr:`output`: :math:`(BC,1,H,W)` + """ + + # Measurement operator + #if self.split: + # meas = super().Acq.meas_op + #else: + #meas = self.Acq.meas_op + meas = self.acqu.meas_op + + # x of shape [b*c, 2M] + bc, _ = x.shape + + # First estimate: Pseudo inverse + # Preprocessing in the measurement domain + x = self.prep(x) # [5, 1024] + + # Save measurements + m = x.clone() # [5, 1024] + + # measurements to image domain processing + x = self.pinv(x, self.acqu.meas_op) # [5, 4096] # shape x = [b*c,N] + #x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # shape x = [b*c,1,h,w] + + # Unroll network + # Ensure step size is positive and monotonically decreasing and larger than self.lamb! + lambs = self.lambs() + for n in range(self.num_iter): + # Projection onto the measurement space + proj = self.acqu.meas_op.forward_H(x) # [5, 1024] + + # Residual + res = proj - m # [5, 1024] + + # Gradient step + x = x + lambs[n]*self.acqu.meas_op.H_adjoint(res) # [5, 4096] + + # Denoising step + x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # [5, 1, 64, 64] + x = self.denoi(x) + x = x.view(bc, self.acqu.meas_op.N) # [5, 4096] + return x + \ No newline at end of file