Skip to content

Commit 0983ad8

Browse files
committed
Resolve merge conflict: accepted and merged changes from both branches
2 parents da3a035 + 5eee416 commit 0983ad8

File tree

9 files changed

+839
-421
lines changed

9 files changed

+839
-421
lines changed

README.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,19 @@ import spyrit
4646
```
4747

4848
## Examples
49-
Simple reconstruction examples can be found in [spyrit/test/tuto_core_2d.py](https://github.com/openspyrit/spyrit/blob/master/spyrit/test/tuto_core_2d.py).
49+
Simple reconstruction examples can be found in [tutorial](https://github.com/openspyrit/spyrit/blob/master/spyrit/tutorial).
50+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_core_2d_short.ipynb)
51+
5052

5153
# API Documentation
5254
https://spyrit.readthedocs.io/
5355

5456
# Contributors (alphabetical order)
57+
* Juan Abascal - [Website](https://juanabascal78.wixsite.com/juan-abascal-webpage)
5558
* Thomas Baudier
5659
* Sebastien Crombez
5760
* Nicolas Ducros - [Website](https://www.creatis.insa-lyon.fr/~ducros/WebPage/index.html)
58-
* Antonio Tomas Lorente Mur - [Website](https://www.creatis.insa-lyon.fr/~lorente/)
61+
* Antonio Tomas Lorente Mur - [Website]( https://sites.google.com/view/antonio-lorente-mur/)
5962
* Fadoua Taia-Alaoui
6063

6164
# How to cite?

spyrit/core/prep.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ def sigma(self, x: torch.tensor) -> torch.tensor:
302302
:attr:`x`: batch of images in the Hadamard domain
303303
304304
Shape:
305-
- Input: :math:`(B,2*M)` where :math:`B` is the batch dimension
306-
- Output: :math:`(B, M)`
305+
- Input: :math:`(*,2*M)` :math:`*` indicates one or more dimensions
306+
- Output: :math:`(*, M)`
307307
308308
Example:
309309
>>> x = torch.rand([10,2*400], dtype=torch.float)
@@ -312,7 +312,7 @@ def sigma(self, x: torch.tensor) -> torch.tensor:
312312
torch.Size([10, 400])
313313
314314
"""
315-
x = x[:,self.even_index] + x[:,self.odd_index]
315+
x = x[...,self.even_index] + x[...,self.odd_index]
316316
x = 4*x/(self.alpha**2) # Cov is in [-1,1] so *4
317317
return x
318318

spyrit/core/recon.py

+120-83
Original file line numberDiff line numberDiff line change
@@ -744,109 +744,146 @@ def reconstruct_expe(self, x):
744744

745745
return x
746746

747-
class PositiveParameters(nn.Module):
748-
def __init__(self, size, val_min=1e-6):
749-
super(PositiveParameters, self).__init__()
750-
self.val_min = torch.tensor(val_min)
751-
self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True)
747+
#%%===========================================================================================
748+
class DCDRUNet(DCNet):
749+
# ===========================================================================================
750+
r""" Denoised completion reconstruction network based on DRUNet wich concatenates a
751+
noise level map to the input
752+
753+
.. math:
752754
753-
def forward(self):
754-
return torch.abs(self.params)
755-
756-
class PositiveMonoIncreaseParameters(PositiveParameters):
757-
def __init__(self, size, val_min=0.000001):
758-
super().__init__(size, val_min)
759755
760-
def forward(self):
761-
# cumsum in opposite order
762-
return super().forward().cumsum(dim=0).flip(dims=[0])
756+
Args:
757+
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
758+
759+
:attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
760+
761+
:attr:`sigma`: UPDATE!! Tikhonov reconstruction operator of type
762+
:class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()`
763+
764+
:attr:`denoi` (optional): Image denoising operator
765+
(see :class:`~spyrit.core.nnet`).
766+
Default :class:`~spyrit.core.nnet.Identity`
767+
768+
:attr:`noise_level` (optional): Noise level in the range [0, 255], default is noise_level=5
763769
764-
# class PositiveMonoDecreaseParameters(nn.Module):
765-
# def __init__(self, size, val_min=1e-6):
766-
# super(PositiveMonoDecreaseParameters, self).__init__()
767-
# self.val_min = torch.tensor(val_min)
768-
# self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True)
769-
# self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
770+
771+
Input / Output:
772+
:attr:`input`: Ground-truth images with concatenated noise level map with
773+
shape :math:`(B,C+1,H,W)`
774+
775+
:attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
776+
777+
Attributes:
778+
:attr:`Acq`: Acquisition operator initialized as :attr:`noise`
770779
771-
# def forward(self):
772-
# for i in range(1, len(self.params)):
773-
# self.params[i].data = torch.clamp(self.params[i].data, min=self.val_min.to(self.device), max=self.params[i-1].data)
774-
# return torch.abs(self.params)
780+
:attr:`PreP`: Preprocessing operator initialized as :attr:`prep`
781+
782+
:attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho`
783+
784+
:attr:`Denoi`: Image (DRUNet architecture type) denoising operator
785+
initialized as :attr:`denoi`
775786
787+
788+
Example:
789+
>>> B, C, H, M = 10, 1, 64, 64**2
790+
>>> Ord = np.ones((H,H))
791+
>>> meas = HadamSplit(M, H, Ord)
792+
>>> noise = NoNoise(meas)
793+
>>> prep = SplitPoisson(1.0, M, H*H)
794+
>>> sigma = np.random.random([H**2, H**2])
795+
>>> n_channels = 1 # 1 for grayscale image
796+
>>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
797+
>>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
798+
downsample_mode="strideconv", upsample_mode="convtranspose")
799+
>>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
800+
>>> z = recnet(x)
801+
>>> print(z.shape)
802+
torch.Size([10, 1, 64, 64])
803+
"""
776804

777-
class UPGD(PinvNet):
778805
def __init__(self,
779806
noise,
780807
prep,
781-
denoi=nn.Identity(),
782-
num_iter = 6,
783-
lamb = 1e-5,
784-
lamb_min = 1e-6,
785-
split=False):
786-
super(UPGD, self).__init__(noise, prep, denoi)
787-
self.num_iter = num_iter
788-
self.lamb = lamb
789-
self.lamb_min = lamb_min
790-
# Set a trainable tensor for the regularization parameter with dimension num_iter
791-
# and constrained to be positive with clamp(min=0.0, max=None)
792-
self.lambs = PositiveMonoIncreaseParameters(num_iter, lamb_min) # shape lambs = [num_iter,1]
793-
#self.noise = noise
794-
self.split = split
795-
808+
sigma,
809+
denoi=nn.Identity(),
810+
noise_level=5):
811+
super().__init__(noise, prep, sigma, denoi)
812+
self.register_buffer('noise_level', torch.FloatTensor([noise_level/255.]))
813+
796814
def reconstruct(self, x):
797815
r""" Reconstruction step of a reconstruction network
798-
799-
Same as :meth:`reconstruct` reconstruct except that:
800-
801-
1. The regularization parameter is trainable
802816
803817
Args:
804818
:attr:`x`: raw measurement vectors
805819
806820
Shape:
807-
:attr:`x`: :math:`(BC,2M)`
821+
:attr:`x`: raw measurement vectors with shape :math:`(BC,2M)`
808822
809-
:attr:`output`: :math:`(BC,1,H,W)`
823+
:attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)`
824+
825+
Example:
826+
>>> B, C, H, M = 10, 1, 64, 64**2
827+
>>> Ord = np.ones((H,H))
828+
>>> meas = HadamSplit(M, H, Ord)
829+
>>> noise = NoNoise(meas)
830+
>>> prep = SplitPoisson(1.0, M, H*H)
831+
>>> sigma = np.random.random([H**2, H**2])
832+
>>> n_channels = 1 # 1 for grayscale image
833+
>>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
834+
>>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
835+
downsample_mode="strideconv", upsample_mode="convtranspose")
836+
>>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
837+
>>> x = torch.rand((B*C,2*M), dtype=torch.float)
838+
>>> z = recnet.reconstruct(x)
839+
>>> print(z.shape)
840+
torch.Size([10, 1, 64, 64])
810841
"""
811-
812-
# Measurement operator
813-
#if self.split:
814-
# meas = super().Acq.meas_op
815-
#else:
816-
#meas = self.Acq.meas_op
817-
meas = self.acqu.meas_op
818-
819842
# x of shape [b*c, 2M]
820-
bc, _ = x.shape
821-
822-
# First estimate: Pseudo inverse
823-
# Preprocessing in the measurement domain
824-
x = self.prep(x) # [5, 1024]
825-
826-
# Save measurements
827-
m = x.clone() # [5, 1024]
828-
843+
844+
bc, _ = x.shape
845+
846+
# Preprocessing
847+
var_noi = self.prep.sigma(x)
848+
x = self.prep(x) # shape x = [b*c, M]
849+
829850
# measurements to image domain processing
830-
x = self.pinv(x, self.acqu.meas_op) # [5, 4096] # shape x = [b*c,N]
831-
#x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # shape x = [b*c,1,h,w]
832-
833-
# Unroll network
834-
# Ensure step size is positive and monotonically decreasing and larger than self.lamb!
835-
lambs = self.lambs()
836-
for n in range(self.num_iter):
837-
# Projection onto the measurement space
838-
proj = self.acqu.meas_op.forward_H(x) # [5, 1024]
839-
840-
# Residual
841-
res = proj - m # [5, 1024]
851+
x_0 = torch.zeros((bc, self.Acq.meas_op.N), device=x.device)
852+
x = self.tikho(x, x_0, var_noi, self.Acq.meas_op)
853+
x = x.view(bc,1,self.Acq.meas_op.h, self.Acq.meas_op.w) # shape x = [b*c,1,h,w]
854+
855+
# Image domain denoising
856+
x = self.concat_noise_map(x)
857+
x = self.denoi(x)
858+
859+
return x
842860

843-
# Gradient step
844-
x = x + lambs[n]*self.acqu.meas_op.H_adjoint(res) # [5, 4096]
861+
def concat_noise_map(self, x):
862+
r""" Concatenation of noise level map to reconstructed images
863+
864+
Args:
865+
:attr:`x`: reconstructed images from the reconstruction layer
866+
867+
Shape:
868+
:attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)`
869+
870+
:attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)`
871+
"""
845872

846-
# Denoising step
847-
x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # [5, 1, 64, 64]
848-
x = self.denoi(x)
849-
x = x.view(bc, self.acqu.meas_op.N) # [5, 4096]
850-
return x
851-
873+
b, c, h, w = x.shape
874+
x = 0.5*(x + 1)
875+
x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1)
876+
return x
852877

878+
def set_noise_level(self, noise_level):
879+
r""" Reset noise level value
880+
881+
Args:
882+
:attr:`noise_level`: noise level value in the range [0, 255]
883+
884+
Shape:
885+
:attr:`noise_level`: float value noise level :math:`(1)`
886+
887+
:attr:`output`: noise level tensor with shape :math:`(1)`
888+
"""
889+
self.noise_level = torch.FloatTensor([noise_level/255.])

spyrit/external/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)