Skip to content

Commit

Permalink
Renamed file
Browse files Browse the repository at this point in the history
  • Loading branch information
jabascal committed Jul 12, 2023
1 parent 4b2b502 commit 143dbd3
Showing 1 changed file with 304 additions and 0 deletions.
304 changes: 304 additions & 0 deletions spyrit/tutorial/tuto_train_colab.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openspyrit/spyrit/blob/master/spyrit/tutorial/tuto_train_colab.ipynb)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial to train a reconstruction network "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Tutorial to train a reconstruction network for 2D single-pixel imaging on stl10.\n",
"\n",
"Current example trains DCNET (data completion with UNet denoising with 0.5 M trainable parameters). "
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Settings and requirements"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import datetime"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"First, mount google drive to import modules spyrit modules."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set google colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mode_colab = True\n",
"if (mode_colab is True):\n",
" # Connect to googledrive\n",
" #if 'google.colab' in str(get_ipython()):\n",
" # Mount google drive to access files via colab\n",
" from google.colab import drive\n",
" drive.mount(\"/content/gdrive\")\n",
" %cd /content/gdrive/MyDrive/\n",
"\n",
" # For the profiler\n",
" !pip install -U tensorboard-plugin-profile\n",
"\n",
" # Load the TensorBoard notebook extension\n",
" %load_ext tensorboard"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"On colab, hoose GPU at *Runtime/Change runtime type*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Clone Spyrit package"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Clone and install spyrit package if not installedClone and install spyrit package if not installed or move to spyrit folder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"install_spyrit = True\n",
"if (mode_colab is True):\n",
" if install_spyrit is True:\n",
" # Clone and install\n",
" !git clone https://github.com/openspyrit/spyrit.git\n",
" %cd spyrit\n",
" !pip install -e .\n",
"\n",
" # Checkout to ongoing branch\n",
" !git fetch --all\n",
" else:\n",
" # cd to spyrit folder is already cloned in your drive\n",
" %cd /content/gdrive/MyDrive/Colab_Notebooks/openspyrit/spyrit\n",
"\n",
" # Add paths for modules\n",
" import sys\n",
" sys.path.append('./spyrit/core')\n",
" sys.path.append('./spyrit/misc')\n",
" sys.path.append('./spyrit/tutorial')\n",
"else:\n",
" # Change path to spyrit/\n",
" os.chdir('../..')\n",
" !pwd"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download data"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Download covariance matrix. Alternatively install *openspyrit/spas* package:\n",
"```\n",
"├───stats\n",
"│ ├───Average_64x64.npy\n",
"│ ├───Cov_64x64.npy\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"download_cov = True\n",
"if (download_cov is True):\n",
" !pip install girder-client\n",
" import girder_client\n",
"\n",
" # api Rest url of the warehouse\n",
" url='https://pilot-warehouse.creatis.insa-lyon.fr/api/v1'\n",
" \n",
" # Generate the warehouse client\n",
" gc = girder_client.GirderClient(apiUrl=url)\n",
"\n",
" #%% Download the covariance matrix and mean image\n",
" data_folder = './stat/'\n",
" dataId_list = [\n",
" '63935b624d15dd536f0484a5', # for reconstruction (imageNet, 64)\n",
" '63935a224d15dd536f048496', # for reconstruction (imageNet, 64)\n",
" ]\n",
" for dataId in dataId_list:\n",
" myfile = gc.getFile(dataId)\n",
" gc.downloadFile(dataId, data_folder + myfile['name'])\n",
"\n",
" print(f'Created {data_folder}') \n",
" !ls $data_folder"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Perturbed by Poisson noise (100 photons) and undersampling factor of 4, on stl10 dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parameters\n",
"N0 = 100\n",
"M = 1024\n",
"data_root = './data/'\n",
"data = 'stl10'\n",
"stat_root = './stat'\n",
"now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')\n",
"tb_path = f'runs/runs_stdl10_n100_m1024/{now}' # None\n",
"tb_prof = True # False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run tuto_train\n",
"if (mode_colab is True):\n",
" # Copy tuto_train.py to main directory for colab\n",
" !pwd\n",
" !cp spyrit/tutorial/train.py .\n",
" !python3 train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof\n",
" !rm train.py\n",
"else:\n",
" !python3 spyrit/tutorial/train.py --N0 $N0 --M $M --data_root $data_root --data $data --stat_root $stat_root --tb_path $tb_path --tb_prof $tb_prof"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate the trained model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Launch TensorBoard\n",
"# %tensorboard --logdir $tb_path\n",
"%tensorboard --logdir runs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If run twice tensorboard\n",
"#!lsof -i:6006\n",
"#!kill -9 17387"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "spy",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 143dbd3

Please sign in to comment.