diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py
index 4893c1ed..ce577965 100644
--- a/spyrit/core/recon.py
+++ b/spyrit/core/recon.py
@@ -742,4 +742,148 @@ def reconstruct_expe(self, x):
         # Denormalization
         x = self.prep.denormalize_expe(x, norm, self.Acq.meas_op.h, self.Acq.meas_op.w)
-        return x
\ No newline at end of file
+        return x
+
+#%%===========================================================================================
+class DCDRUNet(DCNet):
+# ===========================================================================================
+    r""" Denoised completion reconstruction network based on DRUNet, which
+    concatenates a noise level map to the input of its denoiser.
+
+    Args:
+        :attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
+
+        :attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
+
+        :attr:`sigma`: Covariance prior with shape :math:`(N, N)`, where
+        :math:`N = HW`, used by the Tikhonov reconstruction operator of type
+        :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()`
+
+        :attr:`denoi` (optional): Image denoising operator
+        (see :class:`~spyrit.core.nnet`).
+        Default :class:`~spyrit.core.nnet.Identity`
+
+        :attr:`noise_level` (optional): Noise level in the range [0, 255].
+        Default noise_level=5
+
+    Input / Output:
+        :attr:`input`: Ground-truth images with shape :math:`(B,C,H,W)`;
+        the noise level map is concatenated internally before denoising
+
+        :attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
+
+    Attributes:
+        :attr:`Acq`: Acquisition operator initialized as :attr:`noise`
+
+        :attr:`PreP`: Preprocessing operator initialized as :attr:`prep`
+
+        :attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho`
+
+        :attr:`Denoi`: Image denoising operator (DRUNet architecture)
+        initialized as :attr:`denoi`
+
+    Example:
+        >>> B, C, H, M = 10, 1, 64, 64**2
+        >>> Ord = np.ones((H,H))
+        >>> meas = HadamSplit(M, H, Ord)
+        >>> noise = NoNoise(meas)
+        >>> prep = SplitPoisson(1.0, M, H*H)
+        >>> sigma = np.random.random([H**2, H**2])
+        >>> n_channels = 1  # 1 for grayscale image
+        >>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
+        >>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels,
+        ...                       nc=[64, 128, 256, 512], nb=4, act_mode='R',
+        ...                       downsample_mode="strideconv", upsample_mode="convtranspose")
+        >>> recnet = DCDRUNet(noise, prep, sigma, denoi_drunet)
+        >>> x = torch.rand((B,C,H,H), dtype=torch.float)
+        >>> z = recnet(x)
+        >>> print(z.shape)
+        torch.Size([10, 1, 64, 64])
+    """
+
+    def __init__(self,
+                 noise,
+                 prep,
+                 sigma,
+                 denoi=nn.Identity(),
+                 noise_level=5):
+        super().__init__(noise, prep, sigma, denoi)
+        self.register_buffer('noise_level', torch.FloatTensor([noise_level/255.]))
+
+    def reconstruct(self, x):
+        r""" Reconstruction step of a reconstruction network
+
+        Args:
+            :attr:`x`: raw measurement vectors
+
+        Shape:
+            :attr:`x`: raw measurement vectors with shape :math:`(BC,2M)`
+
+            :attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)`
+
+        Example:
+            >>> B, C, H, M = 10, 1, 64, 64**2
+            >>> Ord = np.ones((H,H))
+            >>> meas = HadamSplit(M, H, Ord)
+            >>> noise = NoNoise(meas)
+            >>> prep = SplitPoisson(1.0, M, H*H)
+            >>> sigma = np.random.random([H**2, H**2])
+            >>> n_channels = 1  # 1 for grayscale image
+            >>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth'
+            >>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels,
+            ...                       nc=[64, 128, 256, 512], nb=4, act_mode='R',
+            ...                       downsample_mode="strideconv", upsample_mode="convtranspose")
+            >>> recnet = DCDRUNet(noise, prep, sigma, denoi_drunet)
+            >>> x = torch.rand((B*C,2*M), dtype=torch.float)
+            >>> z = recnet.reconstruct(x)
+            >>> print(z.shape)
+            torch.Size([10, 1, 64, 64])
+        """
+        # x of shape [b*c, 2M]
+        bc, _ = x.shape
+
+        # Preprocessing
+        var_noi = self.prep.sigma(x)
+        x = self.prep(x)  # shape x = [b*c, M]
+
+        # Measurement-to-image domain mapping (Tikhonov data completion)
+        x_0 = torch.zeros((bc, self.Acq.meas_op.N), device=x.device)
+        x = self.tikho(x, x_0, var_noi, self.Acq.meas_op)
+        x = x.view(bc, 1, self.Acq.meas_op.h, self.Acq.meas_op.w)  # shape x = [b*c,1,h,w]
+
+        # Image domain denoising
+        x = self.concat_noise_map(x)
+        x = self.denoi(x)
+
+        return x
+
+    def concat_noise_map(self, x):
+        r""" Concatenation of the noise level map to reconstructed images
+
+        Args:
+            :attr:`x`: reconstructed images from the reconstruction layer
+
+        Shape:
+            :attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)`
+
+            :attr:`output`: reconstructed images with concatenated noise level
+            map, with shape :math:`(BC,2,H,W)`
+        """
+        b, c, h, w = x.shape
+        # Rescale from [-1, 1] to [0, 1], the range expected by DRUNet
+        x = 0.5*(x + 1)
+        x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1)
+        return x
+
+    def set_noise_level(self, noise_level):
+        r""" Set the noise level
+
+        Args:
+            :attr:`noise_level`: noise level value in the range [0, 255]
+
+        Shape:
+            :attr:`noise_level`: float value noise level :math:`(1)`
+
+            :attr:`output`: noise level tensor with shape :math:`(1)`
+        """
+        # Keep the buffer on its current device when resetting its value
+        self.noise_level = torch.FloatTensor([noise_level/255.]).to(self.noise_level.device)
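+
+# Usage sketch (hypothetical weight path; operators as in the class docstring
+# above -- not part of the original module):
+#
+#     denoi = drunet(in_nc=2, out_nc=1, nc=[64, 128, 256, 512], nb=4)
+#     denoi.load_state_dict(torch.load('drunet_gray.pth'), strict=True)
+#     recnet = DCDRUNet(noise, prep, sigma, denoi)
+#     recnet.set_noise_level(10)       # denoising strength, in [0, 255]
+#     z = recnet.reconstruct(y)        # y: raw split measurements (B, 2M)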
diff --git a/spyrit/external/__init__.py b/spyrit/external/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/spyrit/external/drunet.py b/spyrit/external/drunet.py
new file mode 100644
index 00000000..41b71db3
--- /dev/null
+++ b/spyrit/external/drunet.py
@@ -0,0 +1,369 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from collections import OrderedDict
+
+'''
+    Modified by J Abascal (June 2023) from
+    https://github.com/cszn/DPIR/blob/master/models/network_unet.py
+    Plug-and-Play Image Restoration with Deep Denoiser Prior
+'''
+
+
+class UNetRes(nn.Module):
+    def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'):
+        super(UNetRes, self).__init__()
+
+        self.m_head = conv(in_nc, nc[0], bias=False, mode='C')
+
+        # downsample
+        if downsample_mode == 'avgpool':
+            downsample_block = downsample_avgpool
+        elif downsample_mode == 'maxpool':
+            downsample_block = downsample_maxpool
+        elif downsample_mode == 'strideconv':
+            downsample_block = downsample_strideconv
+        else:
+            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
+
+        self.m_down1 = sequential(*[ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], bias=False, mode='2'))
+        self.m_down2 = sequential(*[ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], bias=False, mode='2'))
+        self.m_down3 = sequential(*[ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], bias=False, mode='2'))
+
+        self.m_body = sequential(*[ResBlock(nc[3], nc[3], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
+
+        # upsample
+        if upsample_mode == 'upconv':
+            upsample_block = upsample_upconv
+        elif upsample_mode == 'pixelshuffle':
+            upsample_block = upsample_pixelshuffle
+        elif upsample_mode == 'convtranspose':
+            upsample_block = upsample_convtranspose
+        else:
+            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+
+        self.m_up3 = sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), *[ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
+        self.m_up2 = sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), *[ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
+        self.m_up1 = sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), *[ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)])
+
+        self.m_tail = conv(nc[0], out_nc, bias=False, mode='C')
+
+    def forward(self, x0):
+        x1 = self.m_head(x0)
+        x2 = self.m_down1(x1)
+        x3 = self.m_down2(x2)
+        x4 = self.m_down3(x3)
+        x = self.m_body(x4)
+        x = self.m_up3(x+x4)
+        x = self.m_up2(x+x3)
+        x = self.m_up1(x+x2)
+        x = self.m_tail(x+x1)
+
+        return x
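+
+# Shape sketch (not in the original DPIR code): the three stride-2 stages
+# require spatial dimensions divisible by 8. For instance,
+# UNetRes(in_nc=2, out_nc=1) maps an input of shape (1, 2, 64, 64) -- a
+# grayscale image plus its noise-level map -- to an output of shape
+# (1, 1, 64, 64).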
+
+
+# ----------------------------------------------
+# Functions taken from basicblock.py
+# https://github.com/cszn/DPIR/tree/master/models
+# ----------------------------------------------
+
+
+# --------------------------------------------
+# Res Block: x + conv(relu(conv(x)))
+# --------------------------------------------
+class ResBlock(nn.Module):
+    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
+        super(ResBlock, self).__init__()
+
+        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
+        if mode[0] in ['R', 'L']:
+            mode = mode[0].lower() + mode[1:]
+
+        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
+
+    def forward(self, x):
+        return x + self.res(x)
+
+
+"""
+# --------------------------------------------
+# Upsampler
+# Kai Zhang, https://github.com/cszn/KAIR
+# --------------------------------------------
+# upsample_pixelshuffle
+# upsample_upconv
+# upsample_convtranspose
+# --------------------------------------------
+"""
+
+
+# --------------------------------------------
+# conv + subp (+ relu)
+# --------------------------------------------
+def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
+    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
+    up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope)
+    return up1
+
+
+# --------------------------------------------
+# nearest_upsample + conv (+ R)
+# --------------------------------------------
+def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
+    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
+    if mode[0] == '2':
+        uc = 'UC'
+    elif mode[0] == '3':
+        uc = 'uC'
+    elif mode[0] == '4':
+        uc = 'vC'
+    mode = mode.replace(mode[0], uc)
+    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope)
+    return up1
+
+
+# --------------------------------------------
+# convTranspose (+ relu)
+# --------------------------------------------
+def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
+    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
+    kernel_size = int(mode[0])
+    stride = int(mode[0])
+    mode = mode.replace(mode[0], 'T')
+    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
+    return up1
+
+
+'''
+# --------------------------------------------
+# Downsampler
+# Kai Zhang, https://github.com/cszn/KAIR
+# --------------------------------------------
+# downsample_strideconv
+# downsample_maxpool
+# downsample_avgpool
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# strideconv (+ relu)
+# --------------------------------------------
+def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
+    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
+    kernel_size = int(mode[0])
+    stride = int(mode[0])
+    mode = mode.replace(mode[0], 'C')
+    down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
+    return down1
+
+
+# --------------------------------------------
+# maxpooling + conv (+ relu)
+# --------------------------------------------
+def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
+    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
+    kernel_size_pool = int(mode[0])
+    stride_pool = int(mode[0])
+    mode = mode.replace(mode[0], 'MC')
+    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
+    pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
+    return sequential(pool, pool_tail)
+
+
+# --------------------------------------------
+# averagepooling + conv (+ relu)
+# --------------------------------------------
+def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
+    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
+    kernel_size_pool = int(mode[0])
+    stride_pool = int(mode[0])
+    mode = mode.replace(mode[0], 'AC')
+    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
+    pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
+    return sequential(pool, pool_tail)
+
+
+def sequential(*args):
+    """Advanced nn.Sequential.
+
+    Args:
+        args: nn.Module and/or nn.Sequential instances to flatten into a
+        single nn.Sequential.
+
+    Returns:
+        nn.Sequential
+    """
+    if len(args) == 1:
+        if isinstance(args[0], OrderedDict):
+            raise NotImplementedError('sequential does not support OrderedDict input.')
+        return args[0]  # No sequential is needed.
+    modules = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
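+
+# A minimal sketch (not in the original code) of what sequential() does:
+# nested nn.Sequential children are unrolled into one flat container.
+if __name__ == "__main__":
+    _inner = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())
+    _flat = sequential(_inner, nn.Conv2d(8, 1, 3, padding=1))
+    print(len(_flat))  # 3: the two unrolled layers plus the trailing conv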
+
+
+# --------------------------------------------
+# return nn.Sequential of (Conv + BN + ReLU)
+# --------------------------------------------
+def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
+    L = []
+    for t in mode:
+        if t == 'C':
+            L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
+        elif t == 'T':
+            L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
+        elif t == 'B':
+            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
+        elif t == 'I':
+            L.append(nn.InstanceNorm2d(out_channels, affine=True))
+        elif t == 'R':
+            L.append(nn.ReLU(inplace=True))
+        elif t == 'r':
+            L.append(nn.ReLU(inplace=False))
+        elif t == 'L':
+            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
+        elif t == 'l':
+            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
+        elif t == '2':
+            L.append(nn.PixelShuffle(upscale_factor=2))
+        elif t == '3':
+            L.append(nn.PixelShuffle(upscale_factor=3))
+        elif t == '4':
+            L.append(nn.PixelShuffle(upscale_factor=4))
+        elif t == 'U':
+            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
+        elif t == 'u':
+            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
+        elif t == 'v':
+            L.append(nn.Upsample(scale_factor=4, mode='nearest'))
+        elif t == 'M':
+            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
+        elif t == 'A':
+            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
+        else:
+            raise NotImplementedError('Undefined type: {}'.format(t))
+    return sequential(*L)
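+
+# A minimal sketch (not in the original code) of the mode-string convention:
+# one character per layer, e.g. 'CBR' -> Conv2d + BatchNorm2d + ReLU, and a
+# leading '2' in the up/downsample helpers sets a factor-2 stride.
+if __name__ == "__main__":
+    _block = conv(64, 64, mode='CBR')                  # Conv2d + BatchNorm2d + ReLU
+    _down = downsample_strideconv(64, 128, mode='2R')  # stride-2 conv + ReLU
+    print(_block, _down)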
+
+
+# --------------------------------------------
+# Functions taken from utils/utils_image.py
+# https://github.com/cszn/DPIR/tree/master/utils
+# --------------------------------------------
+
+'''
+# =======================================
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint)   <---> tensor
+# =======================================
+'''
+
+
+# --------------------------------
+# numpy(single) <---> numpy(uint)
+# --------------------------------
+
+
+def uint2single(img):
+
+    return np.float32(img/255.)
+
+
+def single2uint(img):
+
+    return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+    return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+    return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------
+# numpy(uint) <---> tensor
+# uint (HxWxn_channels (RGB) or G)
+# --------------------------------
+
+
+# convert uint (HxWxn_channels) to 4-dimensional torch tensor
+def uint2tensor4(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint (HxWxn_channels) to 3-dimensional torch tensor
+def uint2tensor3(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert torch tensor to uint
+def tensor2uint(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    return np.uint8((img*255.0).round())
+
+
+# --------------------------------
+# numpy(single) <---> tensor
+# single (HxWxn_channels (RGB) or G)
+# --------------------------------
+
+
+# convert single (HxWxn_channels) to 4-dimensional torch tensor
+def single2tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+def single2tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# convert single (HxWxn_channels) to 3-dimensional torch tensor
+def single2tensor3(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWx1, HxW) to 2-dimensional torch tensor
+def single2tensor2(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).squeeze().float()
+
+
+# convert torch tensor to single
+def tensor2single(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+
+    return img
+
+
+def tensor2single3(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    elif img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return img
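+
+# Round-trip sketch (not in the original code) for the uint8 <-> tensor helpers:
+if __name__ == "__main__":
+    _img = (np.random.rand(64, 64) * 255).astype(np.uint8)  # HxW grayscale
+    _t = uint2tensor4(_img)   # (1, 1, 64, 64) float tensor in [0, 1]
+    _back = tensor2uint(_t)   # back to HxW uint8
+    assert _back.shape == _img.shape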
diff --git a/spyrit/tutorial/tuto_train.py b/spyrit/tutorial/train.py
similarity index 100%
rename from spyrit/tutorial/tuto_train.py
rename to spyrit/tutorial/train.py
diff --git a/spyrit/tutorial/tuto_core_2d_drunet.py b/spyrit/tutorial/tuto_core_2d_drunet.py
new file mode 100644
index 00000000..0a2a3ae2
--- /dev/null
+++ b/spyrit/tutorial/tuto_core_2d_drunet.py
@@ -0,0 +1,304 @@
+r"""
+01. Tutorial 2D - Image reconstruction for single-pixel imaging using a pretrained DRUNet denoising network
+===========================================================================================================
+This tutorial shows how to simulate data and perform image reconstruction with DC-DRUNet
+(data completion with a pretrained DRUNet denoising network) for single-pixel imaging.
+For data simulation, it loads an image from ImageNet and simulates measurements with
+an undersampled Hadamard operator. You can select the number of counts and the
+undersampling factor.
+
+Image reconstruction is performed using the following methods:
+    Pseudo-inverse
+    PInvNet:   Linear net
+    DCNet:     Data completion net with an identity denoiser
+    DCUNet:    Data completion with UNet denoising, trained on the stl10 dataset.
+               Refer to tuto_train_colab.ipynb for an example of how to train DCUNet.
+    DCDRUNet:  Data completion with pretrained DRUNet denoising.
+
+    DRUNet taken from https://github.com/cszn/DPIR
+    (Deep Plug-and-Play Image Restoration (DPIR) toolbox, June 2023)
+"""
+
+
+import numpy as np
+import os
+import matplotlib.pyplot as plt
+from spyrit.core.meas import HadamSplit
+from spyrit.core.noise import NoNoise, Poisson, PoissonApproxGauss
+from spyrit.core.prep import SplitPoisson
+from spyrit.core.recon import PseudoInverse, PinvNet, DCNet, DCDRUNet
+from spyrit.misc.statistics import Cov2Var, data_loaders_stl10, transform_gray_norm
+from spyrit.misc.disp import imagesc
+from spyrit.misc.sampling import meas2img2
+from spyrit.core.nnet import Unet
+from spyrit.core.train import load_net
+
+import torch
+import torchvision
+# pip install girder-client
+# pip install gdown
+import girder_client
+import gdown
+
+from spyrit.external.drunet import UNetRes as drunet
+from spyrit.external.drunet import uint2single, single2tensor4
+
+H = 64               # Image height (assumed square image)
+M = H**2 // 4        # Num measurements = subsampled by a factor of 4
+B = 10               # Batch size
+alpha = 100          # ph/pixel max: number of counts
+download_cov = True  # Download the covariance matrix;
+                     # otherwise, set to the identity
+ind_img = 1          # Image index for image selection
+
+imgs_path = './spyrit/images'
+
+cov_name = './stat/Cov_64x64.npy'
+
+# use GPU, if available
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+###############################################################################
+# Load a batch of images, simulate measurements, and reconstruct with each
+# of the methods listed above.
+
+# Uncomment to use the stl10 dataset instead of the images folder
+# dataloaders = data_loaders_stl10('../../../data', img_size=H, batch_size=10)
+# dataloader = dataloaders['train']
+
+# Create a transform for natural images to normalized grayscale image tensors
+transform = transform_gray_norm(img_size=H)
+
+# Create dataset and loader (expects class folder 'images/test/')
+dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=min(B, len(dataset)))
+
+# Select images
+x0, _ = next(iter(dataloader))
+x0 = x0[ind_img:6, :, :, :]
+x = x0.detach().clone()
+b, c, h, w = x.shape
+x = x.view(b*c, h*w)
+print(f'Shape of incoming image (b*c,h*w): {x.shape}')
+
+# Operators
+#
+# Order matrix with shape (H, H) used to compute the permutation matrix
+# (undersampling by taking the first rows only)
+
+if (download_cov is True):
+    # REST API url of the warehouse
+    url = 'https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'
+
+    # Generate the warehouse client
+    gc = girder_client.GirderClient(apiUrl=url)
+
+    # Download the covariance matrix and mean image
+    data_folder = './stat/'
+    dataId_list = [
+        '63935b624d15dd536f0484a5',  # for reconstruction (imageNet, 64)
+        '63935a224d15dd536f048496',  # for reconstruction (imageNet, 64)
+    ]
+    for dataId in dataId_list:
+        myfile = gc.getFile(dataId)
+        gc.downloadFile(dataId, data_folder + myfile['name'])
+
+    print(f'Created {data_folder}')
+
+try:
+    Cov = np.load(cov_name)
+except Exception:
+    Cov = np.eye(H*H)
+    print(f"Cov matrix {cov_name} not found! Set to the identity")
+
+Ord = Cov2Var(Cov)
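+
+# The order matrix ranks Hadamard coefficients by their prior variance;
+# HadamSplit keeps the M largest-variance patterns (sketch check):
+print(f'Order matrix shape (H, H): {Ord.shape}')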
Set to the identity") + +Ord = Cov2Var(Cov) + +# Measurement operator: +# Computes linear measurements y=Px, where P is a linear operator (matrix) with positive entries +# such that P=[H_{+}; H_{-}]=[max(H,0); max(0,-H)], H=H_{+}-H_{-} +meas_op = HadamSplit(M, H, Ord) + +# Simulates raw split measurements from images in the range [0,1] assuming images provided in range [-1,1] +# y=0.5*H(1 + x) +# noise = NoNoise(meas_op) # noiseless +#noise = Poisson(meas_op, alpha) +noise = PoissonApproxGauss(meas_op, alpha) # faster than Poisson + +# Preprocess the raw data acquired with split measurement operator assuming Poisson noise +prep = SplitPoisson(alpha, meas_op) + +# Reconstruction with pseudoinverse +pinv = PseudoInverse() + +# Reconstruction with for Core module (linear net) +pinvnet = PinvNet(noise, prep) + +# Reconstruction with for DCNet (linear net + denoising net) +dcnet = DCNet(noise, prep, Cov) + +# Pretreined DC UNet (UNet denoising) +denoi = Unet() +dcunet = DCNet(noise, prep, Cov, denoi) + +# Load previously trained UNet model + +# Path to model +models_path = "./model" +model_unet_path = os.path.join(models_path, "dc-net_unet_imagenet_var_N0_10_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light") +if os.path.exists(models_path) is False: + os.mkdir(models_path) + print(f'Created {models_path}') + +try: + # Download weights + url_unet = 'https://drive.google.com/file/d/1LBrjU0B-Tecd4GBRozX9-24LTRzIiMzA/view?usp=drive_link' + gdown.download(url_unet, f'{model_unet_path}.pth', quiet=False,fuzzy=True) + + # Load model from path + load_net(model_unet_path, dcunet, device, False) + print(f'Model {model_unet_path} loaded.') + load_unet = True +except: + print(f'Model {model_unet_path} not found!') + load_unet = False + +# DCDRUNet +# +# Download DRUNet weights +url_drunet = 'https://drive.google.com/file/d/1oSsLjPPn6lqtzraFZLZGmwP_5KbPfTES/view?usp=drive_link' +model_drunet_path = os.path.join(models_path, 'drunet_gray.pth') +try: + gdown.download(url_drunet, model_drunet_path, quiet=False,fuzzy=True) + + # Define denoising network + n_channels = 1 # 1 for grayscale image + denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', + downsample_mode="strideconv", upsample_mode="convtranspose") + + # Load pretrained model + denoi_drunet.load_state_dict(torch.load(model_drunet_path), strict=True) + print(f'Model {model_drunet_path} loaded.') + load_drunet = True +except: + print(f'Model {model_drunet_path} not found!') + load_drunet = False + +if load_drunet is True: + denoi_drunet.eval() + for k, v in denoi_drunet.named_parameters(): + v.requires_grad = False + print(sum(map(lambda x: x.numel(), denoi_drunet.parameters())) ) + + # Define DCDRUNet + #noise_level = 10 + #dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet, noise_level=noise_level) + dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet) + +# Simulate measurements +y = noise(x) +m = prep(y) +print(f'Shape of simulated measurements y: {y.shape}') +print(f'Shape of preprocessed data m: {m.shape}') + +# Reconstructions +# +# Pseudo-inverse +z_pinv = pinv(m, meas_op) +print(f'Shape of reconstructed image z: {z_pinv.shape}') + +# Pseudo-inverse net +pinvnet = pinvnet.to(device) + +x = x0.detach().clone() +x = x.to(device) +z_pinvnet = pinvnet(x) +# z_pinvnet = pinvnet.reconstruct(y) + +# DCNet +#y = pinvnet.acquire(x) # or equivalently here: y = dcnet.acquire(x) +#m = pinvnet.meas2img(y) # zero-padded images (after preprocessing) +dcnet = dcnet.to(device) +z_dcnet = 
+# Reconstructions
+#
+# Pseudo-inverse
+z_pinv = pinv(m, meas_op)
+print(f'Shape of reconstructed image z: {z_pinv.shape}')
+
+# Pseudo-inverse net
+pinvnet = pinvnet.to(device)
+
+x = x0.detach().clone()
+x = x.to(device)
+z_pinvnet = pinvnet(x)
+# z_pinvnet = pinvnet.reconstruct(y)
+
+# DCNet
+# y = pinvnet.acquire(x)   # or equivalently here: y = dcnet.acquire(x)
+# m = pinvnet.meas2img(y)  # zero-padded images (after preprocessing)
+dcnet = dcnet.to(device)
+z_dcnet = dcnet.reconstruct(y.to(device))  # reconstruct from raw measurements
+
+# DC UNet
+if (load_unet is True):
+    dcunet = dcunet.to(device)
+    with torch.no_grad():
+        z_dcunet = dcunet.reconstruct(y.to(device))  # reconstruct from raw measurements
+
+# DC DRUNet
+if (load_drunet is True):
+    # Reconstruct with DCDRUNet
+    # Set a new noise level: the higher the noise level, the stronger the denoising
+    noise_level = 10
+    dcdrunet.set_noise_level(noise_level)
+    dcdrunet = dcdrunet.to(device)
+    with torch.no_grad():
+        # reconstruct from raw measurements
+        z_dcdrunet = dcdrunet.reconstruct(y.to(device))
+
+    denoi_drunet = denoi_drunet.to(device)
+
+    # -----------
+    # Denoise the original image with DRUNet
+    noise_level = 10
+    x_sample = 0.5*(x[0, 0, :, :] + 1).cpu().numpy()
+    imagesc(x_sample, 'Ground-truth image normalized', show=False)
+
+    x_sample = uint2single(255*x_sample)
+    x_sample = single2tensor4(x_sample[:, :, np.newaxis])
+    x_sample = torch.cat((x_sample,
+                          torch.FloatTensor([noise_level/255.]).repeat(1, 1, x_sample.shape[2], x_sample.shape[3])),
+                         dim=1)
+    x_sample = x_sample.to(device)
+    z_den_drunet = denoi_drunet(x_sample)
+
+# Plots
+x_plot = x.view(-1, H, H).cpu().numpy()
+imagesc(x_plot[0, :, :], 'Ground-truth image normalized', show=False)
+
+m_plot = y.numpy()
+m_plot = meas2img2(m_plot.T, Ord)
+m_plot = np.moveaxis(m_plot, -1, 0)
+imagesc(m_plot[0, :, :], 'Simulated measurements', show=False)
+
+m_plot = m.numpy()
+m_plot = meas2img2(m_plot.T, Ord)
+m_plot = np.moveaxis(m_plot, -1, 0)
+imagesc(m_plot[0, :, :], 'Preprocessed data', show=False)
+
+z_plot = z_pinv.view(-1, H, H).numpy()
+imagesc(z_plot[0, :, :], 'Pseudo-inverse reconstruction', show=False)
+
+z_plot = z_pinvnet.view(-1, H, H).cpu().numpy()
+imagesc(z_plot[0, :, :], 'Pseudo-inverse net reconstruction', show=False)
+
+z_plot = z_dcnet.view(-1, H, H).cpu().numpy()
+imagesc(z_plot[0, :, :], 'DCNet reconstruction', show=False)
+
+if (load_unet is True):
+    z_plot = z_dcunet.view(-1, H, H).detach().cpu().numpy()
+    imagesc(z_plot[0, :, :], 'DC UNet reconstruction', show=False)
+
+if (load_drunet is True):
+    # DRUNet denoising
+    z_plot = z_den_drunet.view(-1, H, H).detach().cpu().numpy()
+    imagesc(z_plot[0, :, :], 'DRUNet denoising of the original image', show=False)
+
+    # DCDRUNet
+    z_plot = z_dcdrunet.view(-1, H, H).detach().cpu().numpy()
+    imagesc(z_plot[0, :, :], f'DC DRUNet reconstruction, noise level={noise_level}', show=False)
+
+plt.show()
+
+###############################################################################
+# All methods above reconstruct from the same raw measurements. DCDRUNet
+# differs from DCNet and DCUNet in that the denoising strength can be tuned
+# at reconstruction time through the noise level map (see set_noise_level).
diff --git a/spyrit/tutorial/tuto_core_2d_short.ipynb b/spyrit/tutorial/tuto_core_2d_short.ipynb
index 325faec4..137ece26 100644
--- a/spyrit/tutorial/tuto_core_2d_short.ipynb
+++ b/spyrit/tutorial/tuto_core_2d_short.ipynb
@@ -5,7 +5,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/demo_colab/spyrit/tutorial/tuto_core_2d_short.ipynb)"
+   "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_core_2d_short.ipynb)"
   ]
  },
  {
@@ -67,7 +67,7 @@
   "source": [
    "# Set download data covariance to True for realistic simulations\n",
    "# It taken a few minutes to download the data\n",
-   "download_cov = False"
+   "download_cov = True"
   ]
  },
 {
@@ -96,7 +96,7 @@
  },
  "outputs": [],
  "source": [
-  "mode_colab = False\n",
+  "mode_colab = True\n",
   "if (mode_colab is True):\n",
   "  # Connect to googledrive\n",
   "  #if 'google.colab' in str(get_ipython()):\n",
@@ -172,7 +172,7 @@
  "outputs": [],
  "source": [
   "# #%%capture\n",
-  "install_spyrit = False\n",
+  "install_spyrit = True\n",
  "if (mode_colab is True):\n",
  "  if install_spyrit is True:\n",
  "    # Clone and install\n",
@@ -182,7 +182,6 @@
  "\n",
  "    # Checkout to ongoing branch\n",
  "    !git fetch --all\n",
- "    !git checkout demo_colab\n",
  "    !pip install girder_client\n",
  "  else:\n",
  "    # cd to spyrit folder is already cloned in your drive\n",
@@ -388,7 +387,7 @@
  "Data simulation in spyrit is done by using three operators from using *spyrit.core.meas*: image normalization, split measurements and noise perturbation. In the example below, this corresponds to the following steps:\n",
  "\n",
  "$$\n",
-  "x \\xrightarrow[\\text{Step 1}]{\\text{NoNoise}} \\frac{x+1}{2} \\xrightarrow[\\text{Step 2}]{\\text{HadamSplit}} Px \\xrightarrow[\\text{Step 3}]{\\text{Poisson}} y\n",
+  "x \\xrightarrow[\\text{Step 1}]{\\text{NoNoise}} \\frac{x+1}{2} \\xrightarrow[\\text{Step 2}]{\\text{HadamSplit}} y=Px \\xrightarrow[\\text{Step 3}]{\\text{Poisson}} \\tilde{y}=\\mathcal{P}(\\alpha y)\n",
  "$$\n",
  "\n",
  "- Step 1: Given an image $x$ between $[-1, 1]$, the image is first normalized between $[0, 1]$ as\n",
@@ -417,7 +416,7 @@
  "- Step 3: Data is finally perturbed by Poisson noise as\n",
  "\n",
  "$$\n",
-  "y = \\mathcal{P}(\\alpha Px)\n",
+  "\\tilde{y} = \\mathcal{P}(\\alpha y)\n",
  "$$\n",
  "\n",
  "using spirit's *spyrit.core.noise.Poisson* :\n",
diff --git a/spyrit/tutorial/tuto_run_train_colab.ipynb b/spyrit/tutorial/tuto_train_colab.ipynb
similarity index 88%
rename from spyrit/tutorial/tuto_run_train_colab.ipynb
rename to spyrit/tutorial/tuto_train_colab.ipynb
index eface33b..63b3b6d0 100644
--- a/spyrit/tutorial/tuto_run_train_colab.ipynb
+++ b/spyrit/tutorial/tuto_train_colab.ipynb
@@ -5,7 +5,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/demo_colab/spyrit/tutorial/tuto_run_train_colab.ipynb)"
+   "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_train_colab.ipynb)"
   ]
  },
  {
@@ -13,7 +13,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "# Demo to train DCNET for 2D single-pixel reconstruction"
+   "# Tutorial to train a reconstruction network"
   ]
  },
 {
@@ -21,9 +21,9 @@
"cell_type": "markdown", "metadata": {}, "source": [ - "Demo to train DCNET (data completion with UNet denoising with 0.5 M trainable parameters) for 2D single-pixel imaging on stl10.\n", + "Tutorial to train a reconstruction network for 2D single-pixel imaging on stl10.\n", "\n", - "The first time, it installs spyrit package and spas package for the data (optional). Experiment tracking and visualization with tensorboard is optional.\n" + "Current example trains DCNET (data completion with UNet denoising with 0.5 M trainable parameters). " ] }, { @@ -131,7 +131,6 @@ "\n", " # Checkout to ongoing branch\n", " !git fetch --all\n", - " !git checkout demo_colab\n", " else:\n", " # cd to spyrit folder is already cloned in your drive\n", " %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit\n", @@ -242,11 +241,11 @@ "if (mode_colab is True):\n", " # Copy tuto_train.py to main directory for colab\n", " !pwd\n", - " !cp spyrit/tutorial/tuto_train.py .\n", - " !python3 tuto_train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof\n", - " !rm tuto_train.py\n", + " !cp spyrit/tutorial/train.py .\n", + " !python3 train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof\n", + " !rm train.py\n", "else:\n", - " !python3 spyrit/tutorial/tuto_train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof" + " !python3 spyrit/tutorial/train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof" ] }, {