-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,304 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_train_colab.ipynb)" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Tutorial to train a reconstruction network " | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Tutorial to train a reconstruction network for 2D single-pixel imaging on stl10.\n", | ||
"\n", | ||
"Current example trains DCNET (data completion with UNet denoising with 0.5 M trainable parameters). " | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Settings and requirements" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import datetime" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"First, mount google drive to import modules spyrit modules." | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Set google colab" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"mode_colab = True\n", | ||
"if (mode_colab is True):\n", | ||
" # Connect to googledrive\n", | ||
" #if 'google.colab' in str(get_ipython()):\n", | ||
" # Mount google drive to access files via colab\n", | ||
" from google.colab import drive\n", | ||
" drive.mount(\"/content/gdrive\")\n", | ||
" %cd /content/gdrive/MyDrive/\n", | ||
"\n", | ||
" # For the profiler\n", | ||
" !pip install -U tensorboard-plugin-profile\n", | ||
"\n", | ||
" # Load the TensorBoard notebook extension\n", | ||
" %load_ext tensorboard" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"On colab, hoose GPU at *Runtime/Change runtime type*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!nvidia-smi" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Clone Spyrit package" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Clone and install spyrit package if not installedClone and install spyrit package if not installed or move to spyrit folder" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"install_spyrit = True\n", | ||
"if (mode_colab is True):\n", | ||
" if install_spyrit is True:\n", | ||
" # Clone and install\n", | ||
" !git clone https://github.com/openspyrit/spyrit.git\n", | ||
" %cd spyrit\n", | ||
" !pip install -e .\n", | ||
"\n", | ||
" # Checkout to ongoing branch\n", | ||
" !git fetch --all\n", | ||
" else:\n", | ||
" # cd to spyrit folder is already cloned in your drive\n", | ||
" %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit\n", | ||
"\n", | ||
" # Add paths for modules\n", | ||
" import sys\n", | ||
" sys.path.append('./spyrit/core')\n", | ||
" sys.path.append('./spyrit/misc')\n", | ||
" sys.path.append('./spyrit/tutorial')\n", | ||
"else:\n", | ||
" # Change path to spyrit/\n", | ||
" os.chdir('../..')\n", | ||
" !pwd" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Download data" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Download covariance matrix. Alternatively install *openspyrit/spas* package:\n", | ||
"```\n", | ||
"├───stats\n", | ||
"│ ├───Average_64x64.npy\n", | ||
"│ ├───Cov_64x64.npy\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"download_cov = True\n", | ||
"if (download_cov is True):\n", | ||
" !pip install girder-client\n", | ||
" import girder_client\n", | ||
"\n", | ||
" # api Rest url of the warehouse\n", | ||
" url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'\n", | ||
" \n", | ||
" # Generate the warehouse client\n", | ||
" gc = girder_client.GirderClient(apiUrl=url)\n", | ||
"\n", | ||
" #%% Download the covariance matrix and mean image\n", | ||
" data_folder = './stat/'\n", | ||
" dataId_list = [\n", | ||
" '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)\n", | ||
" '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)\n", | ||
" ]\n", | ||
" for dataId in dataId_list:\n", | ||
" myfile = gc.getFile(dataId)\n", | ||
" gc.downloadFile(dataId, data_folder + myfile['name'])\n", | ||
"\n", | ||
" print(f'Created {data_folder}') \n", | ||
" !ls $data_folder" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Train" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Perturbed by Poisson noise (100 photons) and undersampling factor of 4, on stl10 dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Parameters\n", | ||
"N0 = 100\n", | ||
"M = 1024\n", | ||
"data_root = './data/'\n", | ||
"data = 'stl10'\n", | ||
"stat_root = './stat'\n", | ||
"now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')\n", | ||
"tb_path = f'runs/runs_stdl10_n100_m1024/{now}' # None\n", | ||
"tb_prof = True # False" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Run tuto_train\n", | ||
"if (mode_colab is True):\n", | ||
" # Copy tuto_train.py to main directory for colab\n", | ||
" !pwd\n", | ||
" !cp spyrit/tutorial/train.py .\n", | ||
" !python3 train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof\n", | ||
" !rm train.py\n", | ||
"else:\n", | ||
" !python3 spyrit/tutorial/train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Evaluate the trained model" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Tensorboard" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Launch TensorBoard\n", | ||
"# %tensorboard --logdir $tb_path\n", | ||
"%tensorboard --logdir runs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# If run twice tensorboard\n", | ||
"#!lsof -i:6006\n", | ||
"#!kill -9 17387" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "spy", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python", | ||
"version": "3.11.3" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |