@@ -744,109 +744,146 @@ def reconstruct_expe(self, x):
744
744
745
745
return x
746
746
747
- class PositiveParameters (nn .Module ):
748
- def __init__ (self , size , val_min = 1e-6 ):
749
- super (PositiveParameters , self ).__init__ ()
750
- self .val_min = torch .tensor (val_min )
751
- self .params = nn .Parameter (torch .abs (val_min * torch .ones (size ,1 )), requires_grad = True )
747
+ #%%===========================================================================================
748
+ class DCDRUNet (DCNet ):
749
+ # ===========================================================================================
750
+ r""" Denoised completion reconstruction network based on DRUNet wich concatenates a
751
+ noise level map to the input
752
+
753
+ .. math:
752
754
753
- def forward (self ):
754
- return torch .abs (self .params )
755
-
756
- class PositiveMonoIncreaseParameters (PositiveParameters ):
757
- def __init__ (self , size , val_min = 0.000001 ):
758
- super ().__init__ (size , val_min )
759
755
760
- def forward (self ):
761
- # cumsum in opposite order
762
- return super ().forward ().cumsum (dim = 0 ).flip (dims = [0 ])
756
+ Args:
757
+ :attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
758
+
759
+ :attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
760
+
761
+ :attr:`sigma`: UPDATE!! Tikhonov reconstruction operator of type
762
+ :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()`
763
+
764
+ :attr:`denoi` (optional): Image denoising operator
765
+ (see :class:`~spyrit.core.nnet`).
766
+ Default :class:`~spyrit.core.nnet.Identity`
767
+
768
+ :attr:`noise_level` (optional): Noise level in the range [0, 255], default is noise_level=5
763
769
764
- # class PositiveMonoDecreaseParameters(nn.Module):
765
- # def __init__(self, size, val_min=1e-6):
766
- # super(PositiveMonoDecreaseParameters, self).__init__()
767
- # self.val_min = torch.tensor(val_min)
768
- # self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True)
769
- # self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
770
+
771
+ Input / Output:
772
+ :attr:`input`: Ground-truth images with concatenated noise level map with
773
+ shape :math:`(B,C+1,H,W)`
774
+
775
+ :attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
776
+
777
+ Attributes:
778
+ :attr:`Acq`: Acquisition operator initialized as :attr:`noise`
770
779
771
- # def forward(self):
772
- # for i in range(1, len(self.params)):
773
- # self.params[i].data = torch.clamp(self.params[i].data, min=self.val_min.to(self.device), max=self.params[i-1].data)
774
- # return torch.abs(self.params)
780
+ :attr:`PreP`: Preprocessing operator initialized as :attr:`prep`
781
+
782
+ :attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho`
783
+
784
+ :attr:`Denoi`: Image (DRUNet architecture type) denoising operator
785
+ initialized as :attr:`denoi`
775
786
787
+
788
+ Example:
789
+ >>> B, C, H, M = 10, 1, 64, 64**2
790
+ >>> Ord = np.ones((H,H))
791
+ >>> meas = HadamSplit(M, H, Ord)
792
+ >>> noise = NoNoise(meas)
793
+ >>> prep = SplitPoisson(1.0, M, H*H)
794
+ >>> sigma = np.random.random([H**2, H**2])
795
+ >>> n_channels = 1 # 1 for grayscale image
796
+ >>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
797
+ >>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
798
+ downsample_mode="strideconv", upsample_mode="convtranspose")
799
+ >>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
800
+ >>> z = recnet(x)
801
+ >>> print(z.shape)
802
+ torch.Size([10, 1, 64, 64])
803
+ """
776
804
777
- class UPGD (PinvNet ):
778
805
def __init__ (self ,
779
806
noise ,
780
807
prep ,
781
- denoi = nn .Identity (),
782
- num_iter = 6 ,
783
- lamb = 1e-5 ,
784
- lamb_min = 1e-6 ,
785
- split = False ):
786
- super (UPGD , self ).__init__ (noise , prep , denoi )
787
- self .num_iter = num_iter
788
- self .lamb = lamb
789
- self .lamb_min = lamb_min
790
- # Set a trainable tensor for the regularization parameter with dimension num_iter
791
- # and constrained to be positive with clamp(min=0.0, max=None)
792
- self .lambs = PositiveMonoIncreaseParameters (num_iter , lamb_min ) # shape lambs = [num_iter,1]
793
- #self.noise = noise
794
- self .split = split
795
-
808
+ sigma ,
809
+ denoi = nn .Identity (),
810
+ noise_level = 5 ):
811
+ super ().__init__ (noise , prep , sigma , denoi )
812
+ self .register_buffer ('noise_level' , torch .FloatTensor ([noise_level / 255. ]))
813
+
796
814
def reconstruct (self , x ):
797
815
r""" Reconstruction step of a reconstruction network
798
-
799
- Same as :meth:`reconstruct` reconstruct except that:
800
-
801
- 1. The regularization parameter is trainable
802
816
803
817
Args:
804
818
:attr:`x`: raw measurement vectors
805
819
806
820
Shape:
807
- :attr:`x`: :math:`(BC,2M)`
821
+ :attr:`x`: raw measurement vectors with shape :math:`(BC,2M)`
808
822
809
- :attr:`output`: :math:`(BC,1,H,W)`
823
+ :attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)`
824
+
825
+ Example:
826
+ >>> B, C, H, M = 10, 1, 64, 64**2
827
+ >>> Ord = np.ones((H,H))
828
+ >>> meas = HadamSplit(M, H, Ord)
829
+ >>> noise = NoNoise(meas)
830
+ >>> prep = SplitPoisson(1.0, M, H*H)
831
+ >>> sigma = np.random.random([H**2, H**2])
832
+ >>> n_channels = 1 # 1 for grayscale image
833
+ >>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
834
+ >>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
835
+ downsample_mode="strideconv", upsample_mode="convtranspose")
836
+ >>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
837
+ >>> x = torch.rand((B*C,2*M), dtype=torch.float)
838
+ >>> z = recnet.reconstruct(x)
839
+ >>> print(z.shape)
840
+ torch.Size([10, 1, 64, 64])
810
841
"""
811
-
812
- # Measurement operator
813
- #if self.split:
814
- # meas = super().Acq.meas_op
815
- #else:
816
- #meas = self.Acq.meas_op
817
- meas = self .acqu .meas_op
818
-
819
842
# x of shape [b*c, 2M]
820
- bc , _ = x .shape
821
-
822
- # First estimate: Pseudo inverse
823
- # Preprocessing in the measurement domain
824
- x = self .prep (x ) # [5, 1024]
825
-
826
- # Save measurements
827
- m = x .clone () # [5, 1024]
828
-
843
+
844
+ bc , _ = x .shape
845
+
846
+ # Preprocessing
847
+ var_noi = self .prep .sigma (x )
848
+ x = self .prep (x ) # shape x = [b*c, M]
849
+
829
850
# measurements to image domain processing
830
- x = self .pinv (x , self .acqu .meas_op ) # [5, 4096] # shape x = [b*c,N]
831
- #x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # shape x = [b*c,1,h,w]
832
-
833
- # Unroll network
834
- # Ensure step size is positive and monotonically decreasing and larger than self.lamb!
835
- lambs = self .lambs ()
836
- for n in range (self .num_iter ):
837
- # Projection onto the measurement space
838
- proj = self .acqu .meas_op .forward_H (x ) # [5, 1024]
839
-
840
- # Residual
841
- res = proj - m # [5, 1024]
851
+ x_0 = torch .zeros ((bc , self .Acq .meas_op .N ), device = x .device )
852
+ x = self .tikho (x , x_0 , var_noi , self .Acq .meas_op )
853
+ x = x .view (bc ,1 ,self .Acq .meas_op .h , self .Acq .meas_op .w ) # shape x = [b*c,1,h,w]
854
+
855
+ # Image domain denoising
856
+ x = self .concat_noise_map (x )
857
+ x = self .denoi (x )
858
+
859
+ return x
842
860
843
- # Gradient step
844
- x = x + lambs [n ]* self .acqu .meas_op .H_adjoint (res ) # [5, 4096]
861
+ def concat_noise_map (self , x ):
862
+ r""" Concatenation of noise level map to reconstructed images
863
+
864
+ Args:
865
+ :attr:`x`: reconstructed images from the reconstruction layer
866
+
867
+ Shape:
868
+ :attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)`
869
+
870
+ :attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)`
871
+ """
845
872
846
- # Denoising step
847
- x = x .view (bc ,1 ,self .acqu .meas_op .h , self .acqu .meas_op .w ) # [5, 1, 64, 64]
848
- x = self .denoi (x )
849
- x = x .view (bc , self .acqu .meas_op .N ) # [5, 4096]
850
- return x
851
-
873
+ b , c , h , w = x .shape
874
+ x = 0.5 * (x + 1 )
875
+ x = torch .cat ((x , self .noise_level .expand (b , 1 , h , w )), dim = 1 )
876
+ return x
852
877
878
+ def set_noise_level (self , noise_level ):
879
+ r""" Reset noise level value
880
+
881
+ Args:
882
+ :attr:`noise_level`: noise level value in the range [0, 255]
883
+
884
+ Shape:
885
+ :attr:`noise_level`: float value noise level :math:`(1)`
886
+
887
+ :attr:`output`: noise level tensor with shape :math:`(1)`
888
+ """
889
+ self .noise_level = torch .FloatTensor ([noise_level / 255. ])
0 commit comments