Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tutorial to train a reconstruction network: Network can be selected #107

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions spyrit/tutorial/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
parser.add_argument("--stat_root", type=str, default="./stat/", help="Path to precomputed data")
parser.add_argument("--arch", type=str, default="dc-net", help="Choose among 'dc-net','pinv-net',")
parser.add_argument("--denoi", type=str, default="unet", help="Choose among 'cnn','cnnbn', 'unet'")
parser.add_argument("--device", type=str, default="", help="Choose among 'cuda','cpu'")
#parser.add_argument("--no_denoi", default=False, action='store_true', help="No denoising layer")


Expand All @@ -68,19 +69,17 @@
opt.model_root = Path(opt.model_root)
opt.data_root = Path(opt.data_root)

#if opt.data == 'stl10':
# opt.data_root = '../../../data/'

opt.arch = "upgd"

print(opt)

#==========================================================================
# 0. Setting up parameters for training
#==========================================================================
# The device of the machine, number of workers...
#
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if opt.device:
device = torch.device(opt.device)
else:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

#==========================================================================
Expand Down
131 changes: 131 additions & 0 deletions spyrit/tutorial/tuto_core_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
r"""
04. Tutorial 2D - Tutorial to train a reconstruction network
======================
This tutorial shows how to train a reconstruction network for 2D single-pixel imaging
on stl10. Training is performed by a call to *train.py*. Several parameters allow
to modify acquisition, network and training (network architecture),
optimisation and the use of tensorboard.

Currently you can train the following networks by modifying the network architecture variable *arch*:

- 'dc-net': Denoised Completion Network (DCNet).
- 'pinv-net': Pseudo Inverse Network (PinvNet).
- 'upgd': Unrolled proximal gradient descent (UPGD).

and the denoising variable *denoi*: E
- 'cnn': CNN no batch normalization
- 'cnnbn': CNN with batch normalization
- 'unet': UNet (0.5 M trainable parameters)


[Colab version]: https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_train_colab.ipynb

"""

###############################################################################
# Import packages
import os
import datetime
import subprocess

###############################################################################
# Download covariance matrix. Alternatively install *openspyrit/spas* package:
#
# spyrit
# ├───stat
# │ ├───Average_64x64.npy
# │ ├───Cov_64x64.npy
# ├───spirit
#

download_cov = True
if (download_cov is True):
# pip install girder-client
import girder_client

# api Rest url of the warehouse
url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'

# Generate the warehouse client
gc = girder_client.GirderClient(apiUrl=url)

# Download the covariance matrix and mean image
data_folder = './stat/'
dataId_list = [
'63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)
'63935a224d15dd536f048496', # for reconstruction (imageNet, 64)
]
for dataId in dataId_list:
myfile = gc.getFile(dataId)
gc.downloadFile(dataId, data_folder + myfile['name'])

print(f'Created {data_folder}')
os.listdir(data_folder)

###############################################################################
r""" Train
You can choose the following parameters:

- Acquisition:
- --img_size: Height / width dimension, default=64
- --M: Number of undersampling patterns, default=512
- --subs: Among 'var','rect', default="var"

- Network and training:
- --data: stl10 or imagenet, default="stl10"
- --model_root: Path to model saving files, default='./model/'
- --data_root: Path to the dataset, default="./data/"

- --N0: Mean maximum total number of photons, default=10
- --stat_root: Path to precomputed data, default="./stat/"
- --arch: Choose among 'dc-net','pinv-net', 'upgd', default="dc-net"
- --denoi: Choose among 'cnn','cnnbn', 'unet', default="unet"
- --device", Choose among 'cuda','cpu', default="" (cuda if available)

- Optimisation:
- --num_epochs: Number of training epochs, default=30
- --batch_size: Size of each training batch, default=512
- --reg: Regularisation Parameter, default=1e-7
- --step_size: Scheduler Step Size, default=10
- --gamma: Scheduler Decrease Rate, default=0.5
- --checkpoint_model: Optional path to checkpoint model, default=""
- --checkpoint_interval: Interval between saving model checkpoints, default=0
- Training is done with *Adam* optimizer, *MSELoss*

- Tensorboard:
- --tb_path: Relative path for Tensorboard experiment tracking logs, default=False
- --tb_prof: Code profiler with Tensorboard, default=False
- Logging of scalars *train_loss*, *val_loss* and images (dataset example ground-truth and predictions at different epochs).
"""

###############################################################################
# In this tutorial, data is perturbed by Poisson noise (100 mean photons)
# and undersampling factor of 4, on stl10 dataset.
# Training is done with default parameters (see above) and using experiment tracking with tensorboard.

# Parameters
N0 = 100
M = 1024
data_root = './data/'
data = 'stl10'
stat_root = './stat'

# Tensorboard logs path
now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
tb_path = f'runs/runs_stdl10_n100_m1024/{now}'
tb_prof = True # False

# Run train.py
cmd = ['python', 'spyrit/tutorial/train.py', '--N0', str(N0), '--M', str(M), '--data_root', str(data_root),
'--data', str(data), '--stat_root', str(stat_root), '--tb_path', str(tb_path), '--tb_prof', str(tb_prof),
'--device', str('cpu')]
subprocess.run(cmd, check=True)

###############################################################################
# Tensorboard
#
# To launch tensorboard, run in a terminal:
# tensorboard --logdir <tb_path>
#
# Select *SCALARS* or *IMAGES*. More options are available in the top-right corner.