From 24ee6f4e5c06d20a2a0c5936b5f36938e6fc3f3e Mon Sep 17 00:00:00 2001 From: jabascal Date: Thu, 6 Jul 2023 11:07:48 +0200 Subject: [PATCH 1/6] Added external module and required drunet.py functions --- spyrit/external/__init__.py | 0 spyrit/external/drunet.py | 369 ++++++++++++++++++++++++++++++++++++ 2 files changed, 369 insertions(+) create mode 100644 spyrit/external/__init__.py create mode 100644 spyrit/external/drunet.py diff --git a/spyrit/external/__init__.py b/spyrit/external/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spyrit/external/drunet.py b/spyrit/external/drunet.py new file mode 100644 index 00000000..41b71db3 --- /dev/null +++ b/spyrit/external/drunet.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +import numpy as np +from collections import OrderedDict + +''' + Modified by J Abascal https://github.com/cszn/DPIR/blob/master/models/network_unet.py + June 2023 + Plug-and-Play Image Restoration with Deep Denoiser Prior +''' + + +class UNetRes(nn.Module): + def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose'): + super(UNetRes, self).__init__() + + self.m_head = conv(in_nc, nc[0], bias=False, mode='C') + + # downsample + if downsample_mode == 'avgpool': + downsample_block = downsample_avgpool + elif downsample_mode == 'maxpool': + downsample_block = downsample_maxpool + elif downsample_mode == 'strideconv': + downsample_block = downsample_strideconv + else: + raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) + + self.m_down1 = sequential(*[ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], bias=False, mode='2')) + self.m_down2 = sequential(*[ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], bias=False, mode='2')) + self.m_down3 = sequential(*[ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], bias=False, mode='2')) + + self.m_body = sequential(*[ResBlock(nc[3], nc[3], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) + + # upsample + if upsample_mode == 'upconv': + upsample_block = upsample_upconv + elif upsample_mode == 'pixelshuffle': + upsample_block = upsample_pixelshuffle + elif upsample_mode == 'convtranspose': + upsample_block = upsample_convtranspose + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + + self.m_up3 = sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), *[ResBlock(nc[2], nc[2], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) + self.m_up2 = sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), *[ResBlock(nc[1], nc[1], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) + self.m_up1 = sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), *[ResBlock(nc[0], nc[0], bias=False, mode='C'+act_mode+'C') for _ in range(nb)]) + + self.m_tail = conv(nc[0], out_nc, bias=False, mode='C') + + def forward(self, x0): + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x+x4) + x = self.m_up2(x+x3) + x = self.m_up1(x+x2) + x = self.m_tail(x+x1) + + return x + + + +# ---------------------------------------------- +# Functions taken from basicblock.py +# https://github.com/cszn/DPIR/tree/master/models +# ---------------------------------------------- + + +# -------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# -------------------------------------------- +class ResBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2): + super(ResBlock, self).__init__() + + assert in_channels == out_channels, 'Only support in_channels==out_channels.' + if mode[0] in ['R', 'L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + + def forward(self, x): + #res = self.res(x) + return x + self.res(x) + + + +""" +# -------------------------------------------- +# Upsampler +# Kai Zhang, https://github.com/cszn/KAIR +# -------------------------------------------- +# upsample_pixelshuffle +# upsample_upconv +# upsample_convtranspose +# -------------------------------------------- +""" + + +# -------------------------------------------- +# conv + subp (+ relu) +# -------------------------------------------- +def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' + up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope) + return up1 + + +# -------------------------------------------- +# nearest_upsample + conv (+ R) +# -------------------------------------------- +def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR' + if mode[0] == '2': + uc = 'UC' + elif mode[0] == '3': + uc = 'uC' + elif mode[0] == '4': + uc = 'vC' + mode = mode.replace(mode[0], uc) + up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope) + return up1 + + +# -------------------------------------------- +# convTranspose (+ relu) +# -------------------------------------------- +def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], 'T') + up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + return up1 + + +''' +# -------------------------------------------- +# Downsampler +# Kai Zhang, https://github.com/cszn/KAIR +# -------------------------------------------- +# downsample_strideconv +# downsample_maxpool +# downsample_avgpool +# -------------------------------------------- +''' + + +# -------------------------------------------- +# strideconv (+ relu) +# -------------------------------------------- +def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], 'C') + down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + return down1 + + +# -------------------------------------------- +# maxpooling + conv (+ relu) +# -------------------------------------------- +def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], 'MC') + pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) + pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope) + return sequential(pool, pool_tail) + + +# -------------------------------------------- +# averagepooling + conv (+ relu) +# -------------------------------------------- +def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): + assert len(mode)<4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], 'AC') + pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) + pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope) + return sequential(pool, pool_tail) + + +def sequential(*args): + """Advanced nn.Sequential. + + Args: + nn.Sequential, nn.Module + + Returns: + nn.Sequential + """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + + +# -------------------------------------------- +# return nn.Sequantial of (Conv + BN + ReLU) +# -------------------------------------------- +def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2): + L = [] + for t in mode: + if t == 'C': + L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) + elif t == 'T': + L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)) + elif t == 'B': + L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True)) + elif t == 'I': + L.append(nn.InstanceNorm2d(out_channels, affine=True)) + elif t == 'R': + L.append(nn.ReLU(inplace=True)) + elif t == 'r': + L.append(nn.ReLU(inplace=False)) + elif t == 'L': + L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True)) + elif t == 'l': + L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False)) + elif t == '2': + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == '3': + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == '4': + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == 'U': + L.append(nn.Upsample(scale_factor=2, mode='nearest')) + elif t == 'u': + L.append(nn.Upsample(scale_factor=3, mode='nearest')) + elif t == 'v': + L.append(nn.Upsample(scale_factor=4, mode='nearest')) + elif t == 'M': + L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0)) + elif t == 'A': + L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) + else: + raise NotImplementedError('Undefined type: '.format(t)) + return sequential(*L) + +# -------------------------------------------- +# Functions taken from utils/utils_image.py +# https://github.com/cszn/DPIR/tree/master/utils +# -------------------------------------------- + +''' +# ======================================= +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# ======================================= +''' + + +# -------------------------------- +# numpy(single) <---> numpy(unit) +# -------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint8((img.clip(0, 1)*65535.).round()) + + +# -------------------------------- +# numpy(unit) <---> tensor +# uint (HxWxn_channels (RGB) or G) +# -------------------------------- + + +# convert uint (HxWxn_channels) to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint (HxWxn_channels) to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------- +# numpy(single) <---> tensor +# single (HxWxn_channels (RGB) or G) +# -------------------------------- + + +# convert single (HxWxn_channels) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# convert single (HxWxn_channels) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + +# convert single (HxWx1, HxW) to 2-dimensional torch tensor +def single2tensor2(img): + return torch.from_numpy(np.ascontiguousarray(img)).squeeze().float() + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +def tensor2single3(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + From bff27a4fe80ff861adcf6aed682465e7b0624b39 Mon Sep 17 00:00:00 2001 From: jabascal Date: Thu, 6 Jul 2023 11:19:36 +0200 Subject: [PATCH 2/6] Created DCDRUNET(DCNET) class: data completion with pretrained DRUNet denoising. Demo compares PInvNet, DCNET and DRUNET for undersampled data perturbed with Poisson noise --- spyrit/core/recon.py | 147 ++++++++- spyrit/tutorial/tuto_core_2d_short_drunet.py | 302 +++++++++++++++++++ 2 files changed, 448 insertions(+), 1 deletion(-) create mode 100644 spyrit/tutorial/tuto_core_2d_short_drunet.py diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py index 4893c1ed..d8f6bc77 100644 --- a/spyrit/core/recon.py +++ b/spyrit/core/recon.py @@ -742,4 +742,149 @@ def reconstruct_expe(self, x): # Denormalization x = self.prep.denormalize_expe(x, norm, self.Acq.meas_op.h, self.Acq.meas_op.w) - return x \ No newline at end of file + return x + +#%%=========================================================================================== +class DCDRUNet(DCNet): +# =========================================================================================== + r""" Denoised completion reconstruction network based on DRUNet wich concatenates a + noise level map to the input + + .. math: + + + Args: + :attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`) + + :attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`) + + :attr:`sigma`: UPDATE!! Tikhonov reconstruction operator of type + :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()` + + :attr:`denoi` (optional): Image denoising operator + (see :class:`~spyrit.core.nnet`). + Default :class:`~spyrit.core.nnet.Identity` + + :attr:`noise_level` (optional): Noise level in the range [0, 50] for an + image between [0, 255], default is noise_level=5 + + + Input / Output: + :attr:`input`: Ground-truth images with concatenated noise level map with + shape :math:`(B,C+1,H,W)` + + :attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)` + + Attributes: + :attr:`Acq`: Acquisition operator initialized as :attr:`noise` + + :attr:`PreP`: Preprocessing operator initialized as :attr:`prep` + + :attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho` + + :attr:`Denoi`: Image (DRUNet architecture type) denoising operator + initialized as :attr:`denoi` + + + Example: + >>> B, C, H, M = 10, 1, 64, 64**2 + >>> Ord = np.ones((H,H)) + >>> meas = HadamSplit(M, H, Ord) + >>> noise = NoNoise(meas) + >>> prep = SplitPoisson(1.0, M, H*H) + >>> sigma = np.random.random([H**2, H**2]) + >>> n_channels = 1 # 1 for grayscale image + >>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth' + >>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', + downsample_mode="strideconv", upsample_mode="convtranspose") + >>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet) + >>> z = recnet(x) + >>> print(z.shape) + torch.Size([10, 1, 64, 64]) + """ + + def __init__(self, + noise, + prep, + sigma, + denoi=nn.Identity(), + noise_level=5): + super().__init__(noise, prep, sigma, denoi) + self.register_buffer('noise_level', torch.FloatTensor([noise_level/255.])) + + def reconstruct(self, x): + r""" Reconstruction step of a reconstruction network + + Args: + :attr:`x`: raw measurement vectors + + Shape: + :attr:`x`: raw measurement vectors with shape :math:`(BC,2M)` + + :attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)` + + Example: + >>> B, C, H, M = 10, 1, 64, 64**2 + >>> Ord = np.ones((H,H)) + >>> meas = HadamSplit(M, H, Ord) + >>> noise = NoNoise(meas) + >>> prep = SplitPoisson(1.0, M, H*H) + >>> sigma = np.random.random([H**2, H**2]) + >>> n_channels = 1 # 1 for grayscale image + >>> model_drunet_path = './spyrit/drunet/model_zoo/drunet_gray.pth' + >>> denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', + downsample_mode="strideconv", upsample_mode="convtranspose") + >>> recnet = DCDRUNet(noise,prep,sigma,denoi_drunet) + >>> x = torch.rand((B*C,2*M), dtype=torch.float) + >>> z = recnet.reconstruct(x) + >>> print(z.shape) + torch.Size([10, 1, 64, 64]) + """ + # x of shape [b*c, 2M] + + bc, _ = x.shape + + # Preprocessing + var_noi = self.prep.sigma(x) + x = self.prep(x) # shape x = [b*c, M] + + # measurements to image domain processing + x_0 = torch.zeros((bc, self.Acq.meas_op.N), device=x.device) + x = self.tikho(x, x_0, var_noi, self.Acq.meas_op) + x = x.view(bc,1,self.Acq.meas_op.h, self.Acq.meas_op.w) # shape x = [b*c,1,h,w] + + # Image domain denoising + x = self.concat_noise_map(x) + x = self.denoi(x) + + return x + + def concat_noise_map(self, x): + r""" Concatenation of noise level map to reconstructed images + + Args: + :attr:`x`: reconstructed images from the reconstruction layer + + Shape: + :attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)` + + :attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)` + """ + + b, c, h, w = x.shape + x = 0.5*(x + 1) + x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1) + return x + + def set_noise_level(self, noise_level): + r""" Reset noise level value + + Args: + :attr:`noise_level`: noise level value + + Shape: + :attr:`noise_level`: float value noise level :math:`(1)` + + :attr:`output`: noise level tensor with shape :math:`(1)` + """ + self.noise_level = torch.FloatTensor([noise_level/255.]) diff --git a/spyrit/tutorial/tuto_core_2d_short_drunet.py b/spyrit/tutorial/tuto_core_2d_short_drunet.py new file mode 100644 index 00000000..83241e13 --- /dev/null +++ b/spyrit/tutorial/tuto_core_2d_short_drunet.py @@ -0,0 +1,302 @@ +r""" +01. Tutorial 2D - Image reconstruction for single-pixel imaging using pretrained DRUNet denoising network +====================== +This tutorial focuses on Bayesian inversion, a special type of inverse problem +that aims at incorporating prior information in terms of model and data +probabilities in the inversion process. + +It shows how to simulate data and perform image reconstruction with spyrit toolbox. +For data simulation, it loads an image from ImageNet and simulated measurements based on +an undersampled Hadamard operator. You can select number of counts and undersampled factor. + +Image reconstruction is preformed using the following methods: + Pseudo-inverse + PInvNet: Linear net + DCNet: Data completion net with unit matrix denoising + DCUNet: Data completion with UNet denoising, trained on stl10 dataset. + Refer to tuto_run_train_colab.ipynb for an example to train DCUNet. + DCUNetRes: Data completion with pretrained DRUNet denoising. + + DRUNet taken from https://github.com/cszn/DPIR + Deep Plug-and-Play Image Restoration (DPIR) toolbox + June 2023 + +""" + + + +import numpy as np +import os +import matplotlib.pyplot as plt +from spyrit.core.meas import HadamSplit +from spyrit.core.noise import NoNoise, Poisson, PoissonApproxGauss +from spyrit.core.prep import SplitPoisson +from spyrit.core.recon import PseudoInverse, PinvNet, DCNet, DCDRUNet +from spyrit.misc.statistics import Cov2Var, data_loaders_stl10, transform_gray_norm +from spyrit.misc.disp import imagesc +from spyrit.misc.sampling import meas2img2 +from spyrit.core.nnet import Unet +from spyrit.core.train import load_net + +import torch +import torchvision +import girder_client +import gdown + +from spyrit.external.drunet import UNetRes as drunet +from spyrit.external.drunet import uint2single, single2tensor4 +#from spyrit.external import drunet_utils as util + +H = 64 # Image height (assumed squared image) +M = H**2 // 4 # Num measurements = subsampled by factor 2 +B = 10 # Batch size +alpha = 100 # ph/pixel max: number of counts +download_cov = True # Dwonload covariance matrix; + # otherwise, set to unit matrix +load_unet = True # Load pretrained UNet denoising +load_drunet = True # Load pretrained DRUNet denoising +ind_img = 1 # Image index for image selection + +imgs_path = './spyrit/images' + +cov_name = './stat/Cov_64x64.npy' + +# use GPU, if available +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +############################################################################### +# So far we have been able to estimate our posterion mean. What about its +# uncertainties (i.e., posterion covariance)? + +# Uncomment to download stl10 dataset +# A batch of images +# dataloaders = data_loaders_stl10('../../../data', img_size=H, batch_size=10) +# dataloader = dataloaders['train'] + +# Create a transform for natural images to normalized grayscale image tensors +transform = transform_gray_norm(img_size=H) + +# Create dataset and loader (expects class folder 'images/test/') +dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) +dataloader = torch.utils.data.DataLoader(dataset, batch_size = min(B, len(dataset))) + +# Select image +x0, _ = next(iter(dataloader)) +x0 = x0[ind_img:6,:,:,:] +x = x0.detach().clone() +b,c,h,w = x.shape +x = x.view(b*c,h*w) +print(f'Shape of incoming image (b*c,h*w): {x.shape}') + +# Operators +# +# Order matrix with shape (H, H) used to compute the permutation matrix +# (as undersampling taking the first rows only) + +if (download_cov is True): + # api Rest url of the warehouse + url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1' + + # Generate the warehouse client + gc = girder_client.GirderClient(apiUrl=url) + + # Download the covariance matrix and mean image + data_folder = './stat/' + dataId_list = [ + '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64) + '63935a224d15dd536f048496', # for reconstruction (imageNet, 64) + ] + for dataId in dataId_list: + myfile = gc.getFile(dataId) + gc.downloadFile(dataId, data_folder + myfile['name']) + + print(f'Created {data_folder}') + +try: + Cov = np.load(cov_name) +except: + Cov = np.eye(H*H) + print(f"Cov matrix {cov_name} not found! Set to the identity") + +Ord = Cov2Var(Cov) + +# Measurement operator: +# Computes linear measurements y=Px, where P is a linear operator (matrix) with positive entries +# such that P=[H_{+}; H_{-}]=[max(H,0); max(0,-H)], H=H_{+}-H_{-} +meas_op = HadamSplit(M, H, Ord) + +# Simulates raw split measurements from images in the range [0,1] assuming images provided in range [-1,1] +# y=0.5*H(1 + x) +# noise = NoNoise(meas_op) # noiseless +#noise = Poisson(meas_op, alpha) +noise = PoissonApproxGauss(meas_op, alpha) # faster than Poisson + +# Preprocess the raw data acquired with split measurement operator assuming Poisson noise +prep = SplitPoisson(alpha, meas_op) + +# Reconstruction with pseudoinverse +pinv = PseudoInverse() + +# Reconstruction with for Core module (linear net) +pinvnet = PinvNet(noise, prep) + +# Reconstruction with for DCNet (linear net + denoising net) +dcnet = DCNet(noise, prep, Cov) + +# Pretreined DC UNet (UNet denoising) +denoi = Unet() +dcunet = DCNet(noise, prep, Cov, denoi) + +# Load previously trained model +try: + model_path = "./model/dc-net_unet_imagenet_var_N0_10_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light" + #model_path = './model/dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pth' + #dcnet_unet.load_state_dict(torch.load(model_path), loa) + load_net(model_path, dcunet, device, False) + print(f'Model {model_path} loaded.') +except: + print(f'Model {model_path} not found!') + load_unet = False + + +# DCDRUNet +# +# Download weights +model_drunet_path = './spyrit/model_zoo' +url_drunet = 'https://drive.google.com/file/d/1oSsLjPPn6lqtzraFZLZGmwP_5KbPfTES/view?usp=drive_link' + +if os.path.exists(model_drunet_path) is False: + os.mkdir(model_drunet_path) + print(f'Created {model_drunet_path}') + +model_drunet_path = os.path.join(model_drunet_path, 'drunet_gray.pth') +gdown.download(url_drunet, model_drunet_path, quiet=False,fuzzy=True) + +# Define denoising network +n_channels = 1 # 1 for grayscale image +denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', + downsample_mode="strideconv", upsample_mode="convtranspose") + +# Load pretrained model +try: + denoi_drunet.load_state_dict(torch.load(model_drunet_path), strict=True) + print(f'Model {model_drunet_path} loaded.') +except: + print(f'Model {model_path} not found!') + load_drunet = False + +denoi_drunet.eval() +for k, v in denoi_drunet.named_parameters(): + v.requires_grad = False +print(sum(map(lambda x: x.numel(), denoi_drunet.parameters())) ) + +# Define DCDRUNet +#noise_level = 10 +#dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet, noise_level=noise_level) +dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet) + +# Simulate measurements +y = noise(x) +m = prep(y) +print(f'Shape of simulated measurements y: {y.shape}') +print(f'Shape of preprocessed data m: {m.shape}') + +# Reconstructions +# +# Pseudo-inverse +z_pinv = pinv(m, meas_op) +print(f'Shape of reconstructed image z: {z_pinv.shape}') + +# Pseudo-inverse net +pinvnet = pinvnet.to(device) + +x = x0.detach().clone() +x = x.to(device) +z_pinvnet = pinvnet(x) +# z_pinvnet = pinvnet.reconstruct(y) + +# DCNet +#y = pinvnet.acquire(x) # or equivalently here: y = dcnet.acquire(x) +#m = pinvnet.meas2img(y) # zero-padded images (after preprocessing) +dcnet = dcnet.to(device) +z_dcnet = dcnet.reconstruct(y.to(device)) # reconstruct from raw measurements + +# DC UNET +if (load_unet is True): + dcunet = dcunet.to(device) + with torch.no_grad(): + z_dcunet = dcunet.reconstruct(y.to(device)) # reconstruct from raw measurements + +# DC DRUNET +# Denoise original image +noise_level = 10 +x_sample = 0.5*(x[0,0,:,:] + 1).cpu().numpy() +imagesc(x_sample ,'Ground-truth image normalized', show=False) + +x_sample = uint2single(255*x_sample) +x_sample = single2tensor4(x_sample[:,:,np.newaxis]) +x_sample = torch.cat((x_sample, torch.FloatTensor([noise_level/255.]).repeat(1, 1, x_sample.shape[2], x_sample.shape[3])), dim=1) +x_sample = x_sample.to(device) + +if (load_drunet is True): + # Reconstruct + # Uncomment to set a new noise level: The higher the noise, the higher the denoising + noise_level = 10 + dcdrunet.set_noise_level(noise_level) + dcdrunet = dcdrunet.to(device) + with torch.no_grad(): + # reconstruct from raw measurements + z_dcdrunet = dcdrunet.reconstruct(y.to(device)) + + denoi_drunet = denoi_drunet.to(device) + # Denoise + z_den_drunet = denoi_drunet(x_sample) + + +# Plots +x_plot = x.view(-1,H,H).cpu().numpy() +imagesc(x_plot[0,:,:],'Ground-truth image normalized', show=False) + +m_plot = y.numpy() +m_plot = meas2img2(m_plot.T, Ord) +m_plot = np.moveaxis(m_plot,-1, 0) +imagesc(m_plot[0,:,:],'Simulated Measurement', show=False) + +m_plot = m.numpy() +m_plot = meas2img2(m_plot.T, Ord) +m_plot = np.moveaxis(m_plot,-1, 0) +imagesc(m_plot[0,:,:],'Preprocessed data', show=False) + +m_plot = m.numpy() +m_plot = meas2img2(m_plot.T, Ord) +m_plot = np.moveaxis(m_plot,-1, 0) + +z_plot = z_pinv.view(-1,H,H).numpy() +imagesc(z_plot[0,:,:],'Pseudo-inverse reconstruction', show=False) + +z_plot = z_pinvnet.view(-1,H,H).cpu().numpy() +imagesc(z_plot[0,:,:],'Pseudo-inverse net reconstruction', show=False) + +z_plot = z_dcnet.view(-1,H,H).cpu().numpy() +imagesc(z_plot[0,:,:],'DCNet reconstruction', show=False) + +if (load_unet is True): + z_plot = z_dcunet.view(-1,H,H).detach().cpu().numpy() + imagesc(z_plot[0,:,:],'DC UNet reconstruction', show=False) + +if (load_drunet is True): + # DRUNet denoising + z_plot = z_den_drunet.view(-1,H,H).detach().cpu().numpy() + imagesc(z_plot[0,:,:],'DRUNet denoising of original image', show=False) + + # DCDRUNet + z_plot = z_dcdrunet.view(-1,H,H).detach().cpu().numpy() + imagesc(z_plot[0,:,:],f'DC DRUNet reconstruction noise level={noise_level}', show=False) + +plt.show() + +############################################################################### +# Note that here we have been able to compute a sample posterior covariance +# from its estimated samples. By displaying it we can see how both the overall +# variances and the correlation between different parameters have become +# narrower compared to their prior counterparts. From 4adc9de05cdd0f0d9b47c4d7c6fd391aeeec814e Mon Sep 17 00:00:00 2001 From: jabascal Date: Wed, 12 Jul 2023 10:50:28 +0200 Subject: [PATCH 3/6] Modified file names. Update train notebook links --- spyrit/tutorial/{tuto_train.py => train.py} | 0 ...short_drunet.py => tuto_core_2d_drunet.py} | 0 spyrit/tutorial/tuto_run_train_colab.ipynb | 305 ------------------ 3 files changed, 305 deletions(-) rename spyrit/tutorial/{tuto_train.py => train.py} (100%) rename spyrit/tutorial/{tuto_core_2d_short_drunet.py => tuto_core_2d_drunet.py} (100%) delete mode 100644 spyrit/tutorial/tuto_run_train_colab.ipynb diff --git a/spyrit/tutorial/tuto_train.py b/spyrit/tutorial/train.py similarity index 100% rename from spyrit/tutorial/tuto_train.py rename to spyrit/tutorial/train.py diff --git a/spyrit/tutorial/tuto_core_2d_short_drunet.py b/spyrit/tutorial/tuto_core_2d_drunet.py similarity index 100% rename from spyrit/tutorial/tuto_core_2d_short_drunet.py rename to spyrit/tutorial/tuto_core_2d_drunet.py diff --git a/spyrit/tutorial/tuto_run_train_colab.ipynb b/spyrit/tutorial/tuto_run_train_colab.ipynb deleted file mode 100644 index eface33b..00000000 --- a/spyrit/tutorial/tuto_run_train_colab.ipynb +++ /dev/null @@ -1,305 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/demo_colab/spyrit/tutorial/tuto_run_train_colab.ipynb)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Demo to train DCNET for 2D single-pixel reconstruction" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Demo to train DCNET (data completion with UNet denoising with 0.5 M trainable parameters) for 2D single-pixel imaging on stl10.\n", - "\n", - "The first time, it installs spyrit package and spas package for the data (optional). Experiment tracking and visualization with tensorboard is optional.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Settings and requirements" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import datetime" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, mount google drive to import modules spyrit modules." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Set google colab" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mode_colab = True\n", - "if (mode_colab is True):\n", - " # Connect to googledrive\n", - " #if 'google.colab' in str(get_ipython()):\n", - " # Mount google drive to access files via colab\n", - " from google.colab import drive\n", - " drive.mount(\"/content/gdrive\")\n", - " %cd /content/gdrive/MyDrive/\n", - "\n", - " # For the profiler\n", - " !pip install -U tensorboard-plugin-profile\n", - "\n", - " # Load the TensorBoard notebook extension\n", - " %load_ext tensorboard" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "On colab, hoose GPU at *Runtime/Change runtime type*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!nvidia-smi" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Clone Spyrit package" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Clone and install spyrit package if not installedClone and install spyrit package if not installed or move to spyrit folder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "install_spyrit = True\n", - "if (mode_colab is True):\n", - " if install_spyrit is True:\n", - " # Clone and install\n", - " !git clone https://github.com/openspyrit/spyrit.git\n", - " %cd spyrit\n", - " !pip install -e .\n", - "\n", - " # Checkout to ongoing branch\n", - " !git fetch --all\n", - " !git checkout demo_colab\n", - " else:\n", - " # cd to spyrit folder is already cloned in your drive\n", - " %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit\n", - "\n", - " # Add paths for modules\n", - " import sys\n", - " sys.path.append('./spyrit/core')\n", - " sys.path.append('./spyrit/misc')\n", - " sys.path.append('./spyrit/tutorial')\n", - "else:\n", - " # Change path to spyrit/\n", - " os.chdir('../..')\n", - " !pwd" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Download data" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Download covariance matrix. Alternatively install *openspyrit/spas* package:\n", - "```\n", - "├───stats\n", - "│ ├───Average_64x64.npy\n", - "│ ├───Cov_64x64.npy\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "download_cov = True\n", - "if (download_cov is True):\n", - " !pip install girder-client\n", - " import girder_client\n", - "\n", - " # api Rest url of the warehouse\n", - " url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'\n", - " \n", - " # Generate the warehouse client\n", - " gc = girder_client.GirderClient(apiUrl=url)\n", - "\n", - " #%% Download the covariance matrix and mean image\n", - " data_folder = './stat/'\n", - " dataId_list = [\n", - " '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)\n", - " '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)\n", - " ]\n", - " for dataId in dataId_list:\n", - " myfile = gc.getFile(dataId)\n", - " gc.downloadFile(dataId, data_folder + myfile['name'])\n", - "\n", - " print(f'Created {data_folder}') \n", - " !ls $data_folder" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Perturbed by Poisson noise (100 photons) and undersampling factor of 4, on stl10 dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Parameters\n", - "N0 = 100\n", - "M = 1024\n", - "data_root = './data/'\n", - "data = 'stl10'\n", - "stat_root = './stat'\n", - "now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')\n", - "tb_path = f'runs/runs_stdl10_n100_m1024/{now}' # None\n", - "tb_prof = True # False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Run tuto_train\n", - "if (mode_colab is True):\n", - " # Copy tuto_train.py to main directory for colab\n", - " !pwd\n", - " !cp spyrit/tutorial/tuto_train.py .\n", - " !python3 tuto_train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof\n", - " !rm tuto_train.py\n", - "else:\n", - " !python3 spyrit/tutorial/tuto_train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Evaluate the trained model" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Tensorboard" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Launch TensorBoard\n", - "# %tensorboard --logdir $tb_path\n", - "%tensorboard --logdir runs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# If run twice tensorboard\n", - "#!lsof -i:6006\n", - "#!kill -9 17387" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "spy", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 1eeee33afbf3a28456f73f0b3fffd47663823534 Mon Sep 17 00:00:00 2001 From: jabascal Date: Wed, 12 Jul 2023 11:46:25 +0200 Subject: [PATCH 4/6] Update recon (comment on noise level) and tutorial (links and downloads) --- spyrit/core/recon.py | 5 +- spyrit/tutorial/tuto_core_2d_drunet.py | 108 +++++++++++++------------ 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py index d8f6bc77..ce577965 100644 --- a/spyrit/core/recon.py +++ b/spyrit/core/recon.py @@ -765,8 +765,7 @@ class DCDRUNet(DCNet): (see :class:`~spyrit.core.nnet`). Default :class:`~spyrit.core.nnet.Identity` - :attr:`noise_level` (optional): Noise level in the range [0, 50] for an - image between [0, 255], default is noise_level=5 + :attr:`noise_level` (optional): Noise level in the range [0, 255], default is noise_level=5 Input / Output: @@ -880,7 +879,7 @@ def set_noise_level(self, noise_level): r""" Reset noise level value Args: - :attr:`noise_level`: noise level value + :attr:`noise_level`: noise level value in the range [0, 255] Shape: :attr:`noise_level`: float value noise level :math:`(1)` diff --git a/spyrit/tutorial/tuto_core_2d_drunet.py b/spyrit/tutorial/tuto_core_2d_drunet.py index 83241e13..0a2a3ae2 100644 --- a/spyrit/tutorial/tuto_core_2d_drunet.py +++ b/spyrit/tutorial/tuto_core_2d_drunet.py @@ -1,11 +1,8 @@ r""" 01. Tutorial 2D - Image reconstruction for single-pixel imaging using pretrained DRUNet denoising network ====================== -This tutorial focuses on Bayesian inversion, a special type of inverse problem -that aims at incorporating prior information in terms of model and data -probabilities in the inversion process. - -It shows how to simulate data and perform image reconstruction with spyrit toolbox. +This tutorial shows how to simulate data and perform image reconstruction with DC-DRUNet +(data completion with pretrained DRUNet denoising network) for single-pixel imaging. For data simulation, it loads an image from ImageNet and simulated measurements based on an undersampled Hadamard operator. You can select number of counts and undersampled factor. @@ -15,7 +12,7 @@ DCNet: Data completion net with unit matrix denoising DCUNet: Data completion with UNet denoising, trained on stl10 dataset. Refer to tuto_run_train_colab.ipynb for an example to train DCUNet. - DCUNetRes: Data completion with pretrained DRUNet denoising. + DCDRUNet: Data completion with pretrained DRUNet denoising. DRUNet taken from https://github.com/cszn/DPIR Deep Plug-and-Play Image Restoration (DPIR) toolbox @@ -40,6 +37,8 @@ import torch import torchvision +# pip install girder-client +# pip install gdown import girder_client import gdown @@ -53,8 +52,6 @@ alpha = 100 # ph/pixel max: number of counts download_cov = True # Dwonload covariance matrix; # otherwise, set to unit matrix -load_unet = True # Load pretrained UNet denoising -load_drunet = True # Load pretrained DRUNet denoising ind_img = 1 # Image index for image selection imgs_path = './spyrit/images' @@ -147,53 +144,59 @@ denoi = Unet() dcunet = DCNet(noise, prep, Cov, denoi) -# Load previously trained model +# Load previously trained UNet model + +# Path to model +models_path = "./model" +model_unet_path = os.path.join(models_path, "dc-net_unet_imagenet_var_N0_10_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light") +if os.path.exists(models_path) is False: + os.mkdir(models_path) + print(f'Created {models_path}') + try: - model_path = "./model/dc-net_unet_imagenet_var_N0_10_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light" - #model_path = './model/dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pth' - #dcnet_unet.load_state_dict(torch.load(model_path), loa) - load_net(model_path, dcunet, device, False) - print(f'Model {model_path} loaded.') + # Download weights + url_unet = 'https://drive.google.com/file/d/1LBrjU0B-Tecd4GBRozX9-24LTRzIiMzA/view?usp=drive_link' + gdown.download(url_unet, f'{model_unet_path}.pth', quiet=False,fuzzy=True) + + # Load model from path + load_net(model_unet_path, dcunet, device, False) + print(f'Model {model_unet_path} loaded.') + load_unet = True except: - print(f'Model {model_path} not found!') + print(f'Model {model_unet_path} not found!') load_unet = False - # DCDRUNet # -# Download weights -model_drunet_path = './spyrit/model_zoo' +# Download DRUNet weights url_drunet = 'https://drive.google.com/file/d/1oSsLjPPn6lqtzraFZLZGmwP_5KbPfTES/view?usp=drive_link' +model_drunet_path = os.path.join(models_path, 'drunet_gray.pth') +try: + gdown.download(url_drunet, model_drunet_path, quiet=False,fuzzy=True) -if os.path.exists(model_drunet_path) is False: - os.mkdir(model_drunet_path) - print(f'Created {model_drunet_path}') - -model_drunet_path = os.path.join(model_drunet_path, 'drunet_gray.pth') -gdown.download(url_drunet, model_drunet_path, quiet=False,fuzzy=True) - -# Define denoising network -n_channels = 1 # 1 for grayscale image -denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', - downsample_mode="strideconv", upsample_mode="convtranspose") + # Define denoising network + n_channels = 1 # 1 for grayscale image + denoi_drunet = drunet(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', + downsample_mode="strideconv", upsample_mode="convtranspose") -# Load pretrained model -try: + # Load pretrained model denoi_drunet.load_state_dict(torch.load(model_drunet_path), strict=True) print(f'Model {model_drunet_path} loaded.') + load_drunet = True except: - print(f'Model {model_path} not found!') + print(f'Model {model_drunet_path} not found!') load_drunet = False -denoi_drunet.eval() -for k, v in denoi_drunet.named_parameters(): - v.requires_grad = False -print(sum(map(lambda x: x.numel(), denoi_drunet.parameters())) ) +if load_drunet is True: + denoi_drunet.eval() + for k, v in denoi_drunet.named_parameters(): + v.requires_grad = False + print(sum(map(lambda x: x.numel(), denoi_drunet.parameters())) ) -# Define DCDRUNet -#noise_level = 10 -#dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet, noise_level=noise_level) -dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet) + # Define DCDRUNet + #noise_level = 10 + #dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet, noise_level=noise_level) + dcdrunet = DCDRUNet(noise, prep, Cov, denoi_drunet) # Simulate measurements y = noise(x) @@ -228,18 +231,8 @@ z_dcunet = dcunet.reconstruct(y.to(device)) # reconstruct from raw measurements # DC DRUNET -# Denoise original image -noise_level = 10 -x_sample = 0.5*(x[0,0,:,:] + 1).cpu().numpy() -imagesc(x_sample ,'Ground-truth image normalized', show=False) - -x_sample = uint2single(255*x_sample) -x_sample = single2tensor4(x_sample[:,:,np.newaxis]) -x_sample = torch.cat((x_sample, torch.FloatTensor([noise_level/255.]).repeat(1, 1, x_sample.shape[2], x_sample.shape[3])), dim=1) -x_sample = x_sample.to(device) - if (load_drunet is True): - # Reconstruct + # Reconstruct with DCDRUNet # Uncomment to set a new noise level: The higher the noise, the higher the denoising noise_level = 10 dcdrunet.set_noise_level(noise_level) @@ -249,9 +242,18 @@ z_dcdrunet = dcdrunet.reconstruct(y.to(device)) denoi_drunet = denoi_drunet.to(device) - # Denoise - z_den_drunet = denoi_drunet(x_sample) + # ----------- + # Denoise original image with DRUNet + noise_level = 10 + x_sample = 0.5*(x[0,0,:,:] + 1).cpu().numpy() + imagesc(x_sample ,'Ground-truth image normalized', show=False) + + x_sample = uint2single(255*x_sample) + x_sample = single2tensor4(x_sample[:,:,np.newaxis]) + x_sample = torch.cat((x_sample, torch.FloatTensor([noise_level/255.]).repeat(1, 1, x_sample.shape[2], x_sample.shape[3])), dim=1) + x_sample = x_sample.to(device) + z_den_drunet = denoi_drunet(x_sample) # Plots x_plot = x.view(-1,H,H).cpu().numpy() From 4b2b5021c3a9cc0bb527a29c1ef3745c320c1ed9 Mon Sep 17 00:00:00 2001 From: jabascal Date: Wed, 12 Jul 2023 11:59:05 +0200 Subject: [PATCH 5/6] Update tuto_core_2d_short.ipynb: DCNET with UNet denoising. Updated links to branch and parameters to run on colab --- spyrit/tutorial/tuto_core_2d_short.ipynb | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/spyrit/tutorial/tuto_core_2d_short.ipynb b/spyrit/tutorial/tuto_core_2d_short.ipynb index 325faec4..137ece26 100644 --- a/spyrit/tutorial/tuto_core_2d_short.ipynb +++ b/spyrit/tutorial/tuto_core_2d_short.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/demo_colab/spyrit/tutorial/tuto_core_2d_short.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_core_2d_short.ipynb)" ] }, { @@ -67,7 +67,7 @@ "source": [ "# Set download data covariance to True for realistic simulations\n", "# It taken a few minutes to download the data\n", - "download_cov = False" + "download_cov = True" ] }, { @@ -96,7 +96,7 @@ }, "outputs": [], "source": [ - "mode_colab = False\n", + "mode_colab = True\n", "if (mode_colab is True):\n", " # Connect to googledrive\n", " #if 'google.colab' in str(get_ipython()):\n", @@ -172,7 +172,7 @@ "outputs": [], "source": [ "# #%%capture\n", - "install_spyrit = False\n", + "install_spyrit = True\n", "if (mode_colab is True):\n", " if install_spyrit is True:\n", " # Clone and install\n", @@ -182,7 +182,6 @@ "\n", " # Checkout to ongoing branch\n", " !git fetch --all\n", - " !git checkout demo_colab\n", " !pip install girder_client\n", " else:\n", " # cd to spyrit folder is already cloned in your drive\n", @@ -388,7 +387,7 @@ "Data simulation in spyrit is done by using three operators from using *spyrit.core.meas*: image normalization, split measurements and noise perturbation. In the example below, this corresponds to the following steps:\n", "\n", "$$\n", - "x \\xrightarrow[\\text{Step 1}]{\\text{NoNoise}} \\frac{x+1}{2} \\xrightarrow[\\text{Step 2}]{\\text{HadamSplit}} Px \\xrightarrow[\\text{Step 3}]{\\text{Poisson}} y\n", + "x \\xrightarrow[\\text{Step 1}]{\\text{NoNoise}} \\frac{x+1}{2} \\xrightarrow[\\text{Step 2}]{\\text{HadamSplit}} y=Px \\xrightarrow[\\text{Step 3}]{\\text{Poisson}} \\mathcal{P}(\\alpha y)\n", "$$\n", "\n", "- Step 1: Given an image $x$ between $[-1, 1]$, the image is first normalized between $[0, 1]$ as\n", @@ -417,7 +416,7 @@ "- Step 3: Data is finally perturbed by Poisson noise as\n", "\n", "$$\n", - "y = \\mathcal{P}(\\alpha Px)\n", + "\\tilde{y} = \\mathcal{P}(\\alpha y)\n", "$$\n", "\n", "using spirit's *spyrit.core.noise.Poisson* :\n", From 143dbd333eb9646f9773166c6dd836a4cf29925b Mon Sep 17 00:00:00 2001 From: jabascal Date: Wed, 12 Jul 2023 12:10:03 +0200 Subject: [PATCH 6/6] Renamed file --- spyrit/tutorial/tuto_train_colab.ipynb | 304 +++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 spyrit/tutorial/tuto_train_colab.ipynb diff --git a/spyrit/tutorial/tuto_train_colab.ipynb b/spyrit/tutorial/tuto_train_colab.ipynb new file mode 100644 index 00000000..63b3b6d0 --- /dev/null +++ b/spyrit/tutorial/tuto_train_colab.ipynb @@ -0,0 +1,304 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_train_colab.ipynb)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tutorial to train a reconstruction network " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tutorial to train a reconstruction network for 2D single-pixel imaging on stl10.\n", + "\n", + "Current example trains DCNET (data completion with UNet denoising with 0.5 M trainable parameters). " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Settings and requirements" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import datetime" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, mount google drive to import modules spyrit modules." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set google colab" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mode_colab = True\n", + "if (mode_colab is True):\n", + " # Connect to googledrive\n", + " #if 'google.colab' in str(get_ipython()):\n", + " # Mount google drive to access files via colab\n", + " from google.colab import drive\n", + " drive.mount(\"/content/gdrive\")\n", + " %cd /content/gdrive/MyDrive/\n", + "\n", + " # For the profiler\n", + " !pip install -U tensorboard-plugin-profile\n", + "\n", + " # Load the TensorBoard notebook extension\n", + " %load_ext tensorboard" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On colab, hoose GPU at *Runtime/Change runtime type*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Clone Spyrit package" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Clone and install spyrit package if not installedClone and install spyrit package if not installed or move to spyrit folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "install_spyrit = True\n", + "if (mode_colab is True):\n", + " if install_spyrit is True:\n", + " # Clone and install\n", + " !git clone https://github.com/openspyrit/spyrit.git\n", + " %cd spyrit\n", + " !pip install -e .\n", + "\n", + " # Checkout to ongoing branch\n", + " !git fetch --all\n", + " else:\n", + " # cd to spyrit folder is already cloned in your drive\n", + " %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit\n", + "\n", + " # Add paths for modules\n", + " import sys\n", + " sys.path.append('./spyrit/core')\n", + " sys.path.append('./spyrit/misc')\n", + " sys.path.append('./spyrit/tutorial')\n", + "else:\n", + " # Change path to spyrit/\n", + " os.chdir('../..')\n", + " !pwd" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download data" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download covariance matrix. Alternatively install *openspyrit/spas* package:\n", + "```\n", + "├───stats\n", + "│ ├───Average_64x64.npy\n", + "│ ├───Cov_64x64.npy\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "download_cov = True\n", + "if (download_cov is True):\n", + " !pip install girder-client\n", + " import girder_client\n", + "\n", + " # api Rest url of the warehouse\n", + " url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'\n", + " \n", + " # Generate the warehouse client\n", + " gc = girder_client.GirderClient(apiUrl=url)\n", + "\n", + " #%% Download the covariance matrix and mean image\n", + " data_folder = './stat/'\n", + " dataId_list = [\n", + " '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)\n", + " '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)\n", + " ]\n", + " for dataId in dataId_list:\n", + " myfile = gc.getFile(dataId)\n", + " gc.downloadFile(dataId, data_folder + myfile['name'])\n", + "\n", + " print(f'Created {data_folder}') \n", + " !ls $data_folder" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Perturbed by Poisson noise (100 photons) and undersampling factor of 4, on stl10 dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parameters\n", + "N0 = 100\n", + "M = 1024\n", + "data_root = './data/'\n", + "data = 'stl10'\n", + "stat_root = './stat'\n", + "now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')\n", + "tb_path = f'runs/runs_stdl10_n100_m1024/{now}' # None\n", + "tb_prof = True # False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run tuto_train\n", + "if (mode_colab is True):\n", + " # Copy tuto_train.py to main directory for colab\n", + " !pwd\n", + " !cp spyrit/tutorial/train.py .\n", + " !python3 train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof\n", + " !rm train.py\n", + "else:\n", + " !python3 spyrit/tutorial/train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate the trained model" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Launch TensorBoard\n", + "# %tensorboard --logdir $tb_path\n", + "%tensorboard --logdir runs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If run twice tensorboard\n", + "#!lsof -i:6006\n", + "#!kill -9 17387" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spy", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.3" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}