Skip to content

Commit

Permalink
Merge pull request #104 from openspyrit/dcdrunet_new
Browse files Browse the repository at this point in the history
Dcdrunet new
  • Loading branch information
nducros authored Jul 13, 2023
2 parents f5a5897 + f909349 commit 19c1818
Show file tree
Hide file tree
Showing 7 changed files with 832 additions and 17 deletions.
146 changes: 145 additions & 1 deletion spyrit/core/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,4 +742,148 @@ def reconstruct_expe(self, x):
# Denormalization
x = self.prep.denormalize_expe(x, norm, self.Acq.meas_op.h, self.Acq.meas_op.w)

return x
return x

#%%===========================================================================================
class DCDRUNet(DCNet):
# ===========================================================================================
r""" Denoised completion reconstruction network based on DRUNet wich concatenates a
noise level map to the input
.. math:
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
:attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
:attr:`sigma`: UPDATE!! Tikhonov reconstruction operator of type
:class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()`
:attr:`denoi` (optional): Image denoising operator
(see :class:`~spyrit.core.nnet`).
Default :class:`~spyrit.core.nnet.Identity`
:attr:`noise_level` (optional): Noise level in the range [0, 255], default is noise_level=5
Input / Output:
:attr:`input`: Ground-truth images with concatenated noise level map with
shape :math:`(B,C+1,H,W)`
:attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
Attributes:
:attr:`Acq`: Acquisition operator initialized as :attr:`noise`
:attr:`PreP`: Preprocessing operator initialized as :attr:`prep`
:attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho`
:attr:`Denoi`: Image (DRUNet architecture type) denoising operator
initialized as :attr:`denoi`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = np.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> sigma = np.random.random([H**2, H**2])
>>> n_channels = 1 # 1 for grayscale image
>>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
>>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
downsample_mode="strideconv", upsample_mode="convtranspose")
>>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
"""

def __init__(self,
noise,
prep,
sigma,
denoi=nn.Identity(),
noise_level=5):
super().__init__(noise, prep, sigma, denoi)
self.register_buffer('noise_level', torch.FloatTensor([noise_level/255.]))

def reconstruct(self, x):
r""" Reconstruction step of a reconstruction network
Args:
:attr:`x`: raw measurement vectors
Shape:
:attr:`x`: raw measurement vectors with shape :math:`(BC,2M)`
:attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = np.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> sigma = np.random.random([H**2, H**2])
>>> n_channels = 1 # 1 for grayscale image
>>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
>>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
downsample_mode="strideconv", upsample_mode="convtranspose")
>>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
>>> x = torch.rand((B*C,2*M), dtype=torch.float)
>>> z = recnet.reconstruct(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
"""
# x of shape [b*c, 2M]

bc, _ = x.shape

# Preprocessing
var_noi = self.prep.sigma(x)
x = self.prep(x) # shape x = [b*c, M]

# measurements to image domain processing
x_0 = torch.zeros((bc, self.Acq.meas_op.N), device=x.device)
x = self.tikho(x, x_0, var_noi, self.Acq.meas_op)
x = x.view(bc,1,self.Acq.meas_op.h, self.Acq.meas_op.w) # shape x = [b*c,1,h,w]

# Image domain denoising
x = self.concat_noise_map(x)
x = self.denoi(x)

return x

def concat_noise_map(self, x):
r""" Concatenation of noise level map to reconstructed images
Args:
:attr:`x`: reconstructed images from the reconstruction layer
Shape:
:attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)`
:attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)`
"""

b, c, h, w = x.shape
x = 0.5*(x + 1)
x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1)
return x

def set_noise_level(self, noise_level):
r""" Reset noise level value
Args:
:attr:`noise_level`: noise level value in the range [0, 255]
Shape:
:attr:`noise_level`: float value noise level :math:`(1)`
:attr:`output`: noise level tensor with shape :math:`(1)`
"""
self.noise_level = torch.FloatTensor([noise_level/255.])
Empty file added spyrit/external/__init__.py
Empty file.
Loading

0 comments on commit 19c1818

Please sign in to comment.