Skip to content

Commit

Permalink
Resolve merge conflict: accepted and merged changes from both branches
Browse files Browse the repository at this point in the history
  • Loading branch information
jabascal committed Jul 13, 2023
2 parents da3a035 + 5eee416 commit 0983ad8
Show file tree
Hide file tree
Showing 9 changed files with 839 additions and 421 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,19 @@ import spyrit
```

## Examples
Simple reconstruction examples can be found in [spyrit/test/tuto_core_2d.py](https://github.com/openspyrit/spyrit/blob/master/spyrit/test/tuto_core_2d.py).
Simple reconstruction examples can be found in [tutorial](https://github.com/openspyrit/spyrit/blob/master/spyrit/tutorial).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_core_2d_short.ipynb)


# API Documentation
https://spyrit.readthedocs.io/

# Contributors (alphabetical order)
* Juan Abascal - [Website](https://juanabascal78.wixsite.com/juan-abascal-webpage)
* Thomas Baudier
* Sebastien Crombez
* Nicolas Ducros - [Website](https://www.creatis.insa-lyon.fr/~ducros/WebPage/index.html)
* Antonio Tomas Lorente Mur - [Website](https://www.creatis.insa-lyon.fr/~lorente/)
* Antonio Tomas Lorente Mur - [Website]( https://sites.google.com/view/antonio-lorente-mur/)
* Fadoua Taia-Alaoui

# How to cite?
Expand Down
6 changes: 3 additions & 3 deletions spyrit/core/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def sigma(self, x: torch.tensor) -> torch.tensor:
:attr:`x`: batch of images in the Hadamard domain
Shape:
- Input: :math:`(B,2*M)` where :math:`B` is the batch dimension
- Output: :math:`(B, M)`
- Input: :math:`(*,2*M)` :math:`*` indicates one or more dimensions
- Output: :math:`(*, M)`
Example:
>>> x = torch.rand([10,2*400], dtype=torch.float)
Expand All @@ -312,7 +312,7 @@ def sigma(self, x: torch.tensor) -> torch.tensor:
torch.Size([10, 400])
"""
x = x[:,self.even_index] + x[:,self.odd_index]
x = x[...,self.even_index] + x[...,self.odd_index]
x = 4*x/(self.alpha**2) # Cov is in [-1,1] so *4
return x

Expand Down
203 changes: 120 additions & 83 deletions spyrit/core/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,109 +744,146 @@ def reconstruct_expe(self, x):

return x

class PositiveParameters(nn.Module):
def __init__(self, size, val_min=1e-6):
super(PositiveParameters, self).__init__()
self.val_min = torch.tensor(val_min)
self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True)
#%%===========================================================================================
class DCDRUNet(DCNet):
# ===========================================================================================
r""" Denoised completion reconstruction network based on DRUNet wich concatenates a
noise level map to the input
.. math:
def forward(self):
return torch.abs(self.params)

class PositiveMonoIncreaseParameters(PositiveParameters):
def __init__(self, size, val_min=0.000001):
super().__init__(size, val_min)
def forward(self):
# cumsum in opposite order
return super().forward().cumsum(dim=0).flip(dims=[0])
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
:attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
:attr:`sigma`: UPDATE!! Tikhonov reconstruction operator of type
:class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()`
:attr:`denoi` (optional): Image denoising operator
(see :class:`~spyrit.core.nnet`).
Default :class:`~spyrit.core.nnet.Identity`
:attr:`noise_level` (optional): Noise level in the range [0, 255], default is noise_level=5
# class PositiveMonoDecreaseParameters(nn.Module):
# def __init__(self, size, val_min=1e-6):
# super(PositiveMonoDecreaseParameters, self).__init__()
# self.val_min = torch.tensor(val_min)
# self.params = nn.Parameter(torch.abs(val_min*torch.ones(size,1)), requires_grad=True)
# self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Input / Output:
:attr:`input`: Ground-truth images with concatenated noise level map with
shape :math:`(B,C+1,H,W)`
:attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
Attributes:
:attr:`Acq`: Acquisition operator initialized as :attr:`noise`
# def forward(self):
# for i in range(1, len(self.params)):
# self.params[i].data = torch.clamp(self.params[i].data, min=self.val_min.to(self.device), max=self.params[i-1].data)
# return torch.abs(self.params)
:attr:`PreP`: Preprocessing operator initialized as :attr:`prep`
:attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho`
:attr:`Denoi`: Image (DRUNet architecture type) denoising operator
initialized as :attr:`denoi`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = np.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> sigma = np.random.random([H**2, H**2])
>>> n_channels = 1 # 1 for grayscale image
>>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
>>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
downsample_mode="strideconv", upsample_mode="convtranspose")
>>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
"""

class UPGD(PinvNet):
def __init__(self,
noise,
prep,
denoi=nn.Identity(),
num_iter = 6,
lamb = 1e-5,
lamb_min = 1e-6,
split=False):
super(UPGD, self).__init__(noise, prep, denoi)
self.num_iter = num_iter
self.lamb = lamb
self.lamb_min = lamb_min
# Set a trainable tensor for the regularization parameter with dimension num_iter
# and constrained to be positive with clamp(min=0.0, max=None)
self.lambs = PositiveMonoIncreaseParameters(num_iter, lamb_min) # shape lambs = [num_iter,1]
#self.noise = noise
self.split = split

sigma,
denoi=nn.Identity(),
noise_level=5):
super().__init__(noise, prep, sigma, denoi)
self.register_buffer('noise_level', torch.FloatTensor([noise_level/255.]))

def reconstruct(self, x):
r""" Reconstruction step of a reconstruction network
Same as :meth:`reconstruct` reconstruct except that:
1. The regularization parameter is trainable
Args:
:attr:`x`: raw measurement vectors
Shape:
:attr:`x`: :math:`(BC,2M)`
:attr:`x`: raw measurement vectors with shape :math:`(BC,2M)`
:attr:`output`: :math:`(BC,1,H,W)`
:attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = np.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> sigma = np.random.random([H**2, H**2])
>>> n_channels = 1 # 1 for grayscale image
>>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
>>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R',
downsample_mode="strideconv", upsample_mode="convtranspose")
>>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet)
>>> x = torch.rand((B*C,2*M), dtype=torch.float)
>>> z = recnet.reconstruct(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
"""

# Measurement operator
#if self.split:
# meas = super().Acq.meas_op
#else:
#meas = self.Acq.meas_op
meas = self.acqu.meas_op

# x of shape [b*c, 2M]
bc, _ = x.shape

# First estimate: Pseudo inverse
# Preprocessing in the measurement domain
x = self.prep(x) # [5, 1024]

# Save measurements
m = x.clone() # [5, 1024]


bc, _ = x.shape

# Preprocessing
var_noi = self.prep.sigma(x)
x = self.prep(x) # shape x = [b*c, M]

# measurements to image domain processing
x = self.pinv(x, self.acqu.meas_op) # [5, 4096] # shape x = [b*c,N]
#x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # shape x = [b*c,1,h,w]

# Unroll network
# Ensure step size is positive and monotonically decreasing and larger than self.lamb!
lambs = self.lambs()
for n in range(self.num_iter):
# Projection onto the measurement space
proj = self.acqu.meas_op.forward_H(x) # [5, 1024]

# Residual
res = proj - m # [5, 1024]
x_0 = torch.zeros((bc, self.Acq.meas_op.N), device=x.device)
x = self.tikho(x, x_0, var_noi, self.Acq.meas_op)
x = x.view(bc,1,self.Acq.meas_op.h, self.Acq.meas_op.w) # shape x = [b*c,1,h,w]

# Image domain denoising
x = self.concat_noise_map(x)
x = self.denoi(x)

return x

# Gradient step
x = x + lambs[n]*self.acqu.meas_op.H_adjoint(res) # [5, 4096]
def concat_noise_map(self, x):
r""" Concatenation of noise level map to reconstructed images
Args:
:attr:`x`: reconstructed images from the reconstruction layer
Shape:
:attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)`
:attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)`
"""

# Denoising step
x = x.view(bc,1,self.acqu.meas_op.h, self.acqu.meas_op.w) # [5, 1, 64, 64]
x = self.denoi(x)
x = x.view(bc, self.acqu.meas_op.N) # [5, 4096]
return x

b, c, h, w = x.shape
x = 0.5*(x + 1)
x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1)
return x

def set_noise_level(self, noise_level):
r""" Reset noise level value
Args:
:attr:`noise_level`: noise level value in the range [0, 255]
Shape:
:attr:`noise_level`: float value noise level :math:`(1)`
:attr:`output`: noise level tensor with shape :math:`(1)`
"""
self.noise_level = torch.FloatTensor([noise_level/255.])
Empty file added spyrit/external/__init__.py
Empty file.
Loading

0 comments on commit 0983ad8

Please sign in to comment.