diff --git a/README.md b/README.md
index e7ca987..8932448 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,13 @@
+## MeshNet Example
+This basic example provides an overview of the training pipeline for the MeshNet model.
+
+* [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neuroneural/brainchop/blob/master/py2tfjs/MeshNet_Training_Example.ipynb) [MeshNet Basic Training example](./py2tfjs/MeshNet_Training_Example.ipynb)
+
+
+
## Live Demo
To see Brainchop in action please click [here](https://neuroneural.github.io/brainchop).
diff --git a/py2tfjs/MeshNet_Training_Example.ipynb b/py2tfjs/MeshNet_Training_Example.ipynb
new file mode 100644
index 0000000..d4323e5
--- /dev/null
+++ b/py2tfjs/MeshNet_Training_Example.ipynb
@@ -0,0 +1,3231 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ ""
+ ],
+ "metadata": {
+ "id": "A1c2bimsUbI2"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "This notebook is a simple tutorial on training the **MeshNet** model for MRI brain gray matter/white matter (GWM) segmentation. The task segments the brain into three classes: background, white matter, and gray matter. The model is trained on sample volumes from the **Mindboggle 101** brain MRI dataset for a multiclass 3D segmentation task.\n",
+ "\n",
+ "\n",
+    "This training pipeline example is part of the [**Brainchop**](https://neuroneural.github.io/brainchop/) project, where the basic MeshNet model is trained using **PyTorch**, and the resulting model can be converted to a **TensorFlow.js** (tfjs) model for use with Brainchop.\n",
+ "\n",
+    "For more information about the full conversion process, please refer to the repository [Wiki](https://github.com/neuroneural/brainchop/wiki).\n",
+ "\n",
+ "---\n",
+ "\n",
+    "This tutorial was developed by [Pratyush Reddy](mailto:pratyushrg@gmail.com), revised by [Mohamed Masoud](mailto:mohamedemory@gmail.com) and [Sergey Plis](mailto:s.m.plis@gmail.com).\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "tKyWAoiv3mN3"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Oxp9pcTthFZW"
+ },
+ "source": [
+    "# Imports"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "* Installing the nibabel and nilearn packages for reading 3D brain MRI scans.\n",
+    "* Importing essential libraries."
+ ],
+ "metadata": {
+ "id": "BKmrYuJk1NpZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "WPRBNRuVJAbX",
+ "outputId": "d5152b59-1cba-47a3-f322-a44ea9fc2ff5"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Requirement already satisfied: nibabel in /usr/local/lib/python3.10/dist-packages (4.0.2)\n",
+ "Collecting nilearn\n",
+ " Downloading nilearn-0.10.2-py3-none-any.whl (10.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m79.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from nibabel) (1.23.5)\n",
+ "Requirement already satisfied: packaging>=17.0 in /usr/local/lib/python3.10/dist-packages (from nibabel) (23.2)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from nibabel) (67.7.2)\n",
+ "Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from nilearn) (1.3.2)\n",
+ "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from nilearn) (4.9.3)\n",
+ "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from nilearn) (1.5.3)\n",
+ "Requirement already satisfied: requests>=2.25.0 in /usr/local/lib/python3.10/dist-packages (from nilearn) (2.31.0)\n",
+ "Requirement already satisfied: scikit-learn>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from nilearn) (1.2.2)\n",
+ "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from nilearn) (1.11.3)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->nilearn) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->nilearn) (2023.3.post1)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->nilearn) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->nilearn) (3.4)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->nilearn) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.0->nilearn) (2023.7.22)\n",
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.0->nilearn) (3.2.0)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas>=1.1.5->nilearn) (1.16.0)\n",
+ "Installing collected packages: nilearn\n",
+ "Successfully installed nilearn-0.10.2\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install nibabel nilearn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "TNueE888JCQp"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import nibabel as nib\n",
+ "import ipywidgets as widgets\n",
+ "from nilearn import plotting\n",
+ "import matplotlib.pyplot as plt\n",
+ "from IPython.display import display\n",
+ "from collections import OrderedDict\n",
+ "import torch\n",
+ "import os\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.checkpoint import checkpoint_sequential\n",
+ "from torch.utils.data import DataLoader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5VYNwjgLSwLd"
+ },
+ "source": [
+ "# Download required dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "We'll be using samples from the Mindboggle 101 brain MRI dataset for a multiclass 3D segmentation task.\n",
+    "* Contents:\n",
+    "  1. The dataset comprises T1 scans plus the prepared limited labels.\n",
+    "  2. Training (10 image/label pairs), validation (2 pairs), and inference (3 pairs) splits"
+ ],
+ "metadata": {
+ "id": "y99spj1F2Zt-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "uBBYs8LiBvRs",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "0e63a213-802c-41c2-a571-5b7f9730068a"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "--2023-11-27 06:30:53-- https://meshnet-pr-dataset.s3.amazonaws.com/data-1-10.zip\n",
+ "Resolving meshnet-pr-dataset.s3.amazonaws.com (meshnet-pr-dataset.s3.amazonaws.com)... 54.231.200.33, 52.217.196.65, 52.216.24.68, ...\n",
+ "Connecting to meshnet-pr-dataset.s3.amazonaws.com (meshnet-pr-dataset.s3.amazonaws.com)|54.231.200.33|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 128402177 (122M) [application/zip]\n",
+ "Saving to: ‘data-1-10.zip’\n",
+ "\n",
+ "data-1-10.zip 100%[===================>] 122.45M 34.8MB/s in 3.5s \n",
+ "\n",
+ "2023-11-27 06:30:56 (34.8 MB/s) - ‘data-1-10.zip’ saved [128402177/128402177]\n",
+ "\n",
+ "Archive: data-1-10.zip\n",
+ " inflating: data/coords_generator.py \n",
+ " inflating: data/dataset_infer.csv \n",
+ " inflating: data/dataset_train.csv \n",
+ " inflating: data/model.py \n",
+ " inflating: data/brain_dataset(1).py \n",
+ " inflating: data/reader.py \n",
+ " inflating: data/brain_dataset.py \n",
+ " inflating: data/dataset_valid.csv \n",
+ " inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-18/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-9/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-12/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/OASIS-TRT-20_volumes/OASIS-TRT-20-13/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-3/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-6/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/Extra-18_volumes/MMRR-3T7T-2-1/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-15/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-11/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-6/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/Extra-18_volumes/MMRR-3T7T-2-1/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-19/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-8/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-9/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-18/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-5/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-9/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-19/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/Extra-18_volumes/HLN-12-12/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/Extra-18_volumes/HLN-12-12/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-3/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/OASIS-TRT-20_volumes/OASIS-TRT-20-13/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/MMRR-21_volumes/MMRR-21-12/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-8/labels.DKT31.manual+aseg.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-5/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-9/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-15/t1weighted.nii.gz \n",
+ " inflating: data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-11/labels.DKT31.manual+aseg.nii.gz \n"
+ ]
+ }
+ ],
+ "source": [
+ "!wget https://meshnet-pr-dataset.s3.amazonaws.com/data-1-10.zip\n",
+ "!unzip data-1-10.zip\n",
+ "!rm data-1-10.zip"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "slH_hecyS9J8"
+ },
+ "source": [
+    "# Generate gray/white and anatomic labels from the existing label files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "RF1RJ8o-An3X",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "outputId": "77aa823a-7b73-498f-85f6-da7bdfebd467"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'/content'"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 4
+ }
+ ],
+ "source": [
+ "pwd"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "This function takes an input label array and produces a mask that encodes background as 0, white matter as 1, and gray matter as 2, using specific integer codes. The conversion complies with the FreeSurfer label numbering convention and its ColorLUT:\n",
+ "\n",
+ "https://surfer.nmr.mgh.harvard.edu/fswiki/FsTutorial/AnatomicalROI/FreeSurferColorLUT"
+ ],
+ "metadata": {
+ "id": "Z1NBHcjw6_du"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "8jtDcxjUHLFL"
+ },
+ "outputs": [],
+ "source": [
+ "def labels2graywhite(label):\n",
+ " white_code = [2, 41, 7, 16, 46] + [*range(251, 256)]\n",
+ " gray_code = (\n",
+ " [*range(1001, 1004)]\n",
+ " + [*range(1005, 1036)]\n",
+ " + [*range(2001, 2004)]\n",
+ " + [*range(2005, 2036)]\n",
+ " + [*range(8, 14)]\n",
+ " + [*range(17, 21)]\n",
+ " + [*range(26, 29)]\n",
+ " + [*range(47, 56)]\n",
+ " + [*range(58, 61)]\n",
+ " )\n",
+ " white = np.isin(label, white_code)\n",
+ " gray = np.isin(label, gray_code)\n",
+ " return white.astype(np.uint8) + (2 * gray).astype(np.uint8)"
+ ]
+ },
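+  {
+   "cell_type": "markdown",
+   "source": [
+    "A quick sanity check (a minimal sketch on a made-up toy array; 2 and 41 are white-matter codes, 1001 and 17 are gray-matter codes, 99 is neither):"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# toy label array: white matter -> 1, gray matter -> 2, everything else -> 0\n",
+    "toy = np.array([[2, 1001, 0], [41, 17, 99]])\n",
+    "print(labels2graywhite(toy))  # expected: [[1 2 0] [1 2 0]]"
+   ]
+  },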
+ {
+ "cell_type": "markdown",
+ "source": [
+    "This function takes an input label array and maps the FreeSurfer labels to consecutive anatomical region indices."
+ ],
+ "metadata": {
+ "id": "A-k0A8VY7c0G"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "tu11Qg88He3s"
+ },
+ "outputs": [],
+ "source": [
+ "def labels2anatomic(label):\n",
+ " source = (\n",
+ " [0]\n",
+ " + [i for i in range(1001, 1004)]\n",
+ " + [i for i in range(1005, 1036)]\n",
+ " + [i for i in range(2001, 2004)]\n",
+ " + [i for i in range(2005, 2036)]\n",
+ " + [10,49,11,50,12,51,13,52,17,53,18,54,26,58,28,60,2,41,4,5,43,44,14,15,24,16,7,46,8,47,251,252,253,254,255,]\n",
+ " )\n",
+ " labelmap = {x: idx for idx, x in enumerate(source)}\n",
+ "\n",
+ " @np.vectorize\n",
+ " def relabel(x):\n",
+ " y = 0\n",
+ " if x in labelmap:\n",
+ " y = labelmap[x]\n",
+ " return y\n",
+ " return relabel(label).astype(np.uint8)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "This function creates new label file names for the labels generated by the **labels2graywhite** and **labels2anatomic** functions above."
+ ],
+ "metadata": {
+ "id": "FuZQdAPK7fk7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "dAPfH7XZIbXD"
+ },
+ "outputs": [],
+ "source": [
+ "def create_label(label, prefix):\n",
+ " temp=label.split('/')[:-1]\n",
+ " temp.append(prefix+label.split('/')[-1])\n",
+ " return '/'.join(temp)"
+ ]
+ },
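+  {
+   "cell_type": "markdown",
+   "source": [
+    "For example (with a hypothetical path, assuming the cell above has been run):"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# prefixes the file name, keeping the directory part of the path intact\n",
+    "print(create_label('data/sub-1/labels.nii.gz', 'GW'))  # -> data/sub-1/GWlabels.nii.gz"
+   ]
+  },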
+ {
+ "cell_type": "markdown",
+ "source": [
+    "Function to update the dataset CSV files with the new GW (and, optionally, anatomic) label paths"
+ ],
+ "metadata": {
+ "id": "gco9Pyua70ni"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "yQuXBwKZHbRw"
+ },
+ "outputs": [],
+ "source": [
+ "def update_csv(CSV_file):\n",
+ " data = pd.read_csv(CSV_file)\n",
+ " data['GWlabels']=np.array([create_label(i, 'GW') for i in data['labels']], dtype=object)\n",
+ " # data['ANAlabels']=np.array([create_label(i, 'ANA') for i in data['labels']], dtype=object)\n",
+ " # data.drop(['nii_labels'], inplace=True, axis=1)\n",
+ " data.to_csv(CSV_file, index=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "Function to create GW labels using the **labels2graywhite** function"
+ ],
+ "metadata": {
+ "id": "Z9G_RjV58H1O"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "qP32gW1lOU9l"
+ },
+ "outputs": [],
+ "source": [
+ "def create_greywhite(CSV_file):\n",
+ " data = pd.read_csv(CSV_file)\n",
+ " for label,GWlabel in zip(data.labels,data.GWlabels):\n",
+ " print('/'.join(GWlabel.split('/')[:-1]),GWlabel.split('/')[-1])\n",
+ " img_nifti = nib.load(label)\n",
+ " img = np.array(img_nifti.dataobj)\n",
+ " ni_img = nib.Nifti1Image(labels2graywhite(img), affine=np.eye(4))\n",
+ " nib.save(ni_img, os.path.join('/'.join(GWlabel.split('/')[:-1]), GWlabel.split('/')[-1]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "Function to create anatomic labels using the **labels2anatomic** function"
+ ],
+ "metadata": {
+ "id": "aRr7zh5V8jca"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "n9s3gj8DQkfT"
+ },
+ "outputs": [],
+ "source": [
+ "def create_ANA(CSV_file):\n",
+ " data = pd.read_csv(CSV_file)\n",
+ " for label,ANAlabel in zip(data.labels,data.ANAlabels):\n",
+ " print('/'.join(ANAlabel.split('/')[:-1]),ANAlabel.split('/')[-1])\n",
+ " img_nifti = nib.load(label)\n",
+ " img = np.array(img_nifti.dataobj)\n",
+ " ni_img = nib.Nifti1Image(labels2anatomic(img), affine=np.eye(4))\n",
+ " nib.save(ni_img, os.path.join('/'.join(ANAlabel.split('/')[:-1]), ANAlabel.split('/')[-1]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "Update the dataset CSVs (train, infer, valid) with the generated GW labels using **update_csv** and **create_greywhite** (the anatomic-label calls to **create_ANA** are shown commented out)"
+ ],
+ "metadata": {
+ "id": "HIfvSlvb8qv4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "36yU0UqRRVzq",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "49cc9ecb-ce9c-43d4-d652-f046053b1c01"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-19 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-18 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-11 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-8 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-15 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/OASIS-TRT-20_volumes/OASIS-TRT-20-13 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-9 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/Extra-18_volumes/HLN-12-12 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/Extra-18_volumes/MMRR-3T7T-2-1 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-RS-22_volumes/NKI-RS-22-5 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/NKI-TRT-20_volumes/NKI-TRT-20-6 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/MMRR-21_volumes/MMRR-21-9 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/MMRR-21_volumes/MMRR-21-3 GWlabels.DKT31.manual+aseg.nii.gz\n",
+ "data/Mindboggle_101/MMRR-21_volumes/MMRR-21-12 GWlabels.DKT31.manual+aseg.nii.gz\n"
+ ]
+ }
+ ],
+ "source": [
+ "update_csv('./data/dataset_train.csv')\n",
+ "create_greywhite('./data/dataset_train.csv')\n",
+ "# create_ANA('./data/dataset_train.csv')\n",
+ "update_csv('./data/dataset_infer.csv')\n",
+ "create_greywhite('./data/dataset_infer.csv')\n",
+ "# create_ANA('./data/dataset_infer.csv')\n",
+ "update_csv('./data/dataset_valid.csv')\n",
+ "create_greywhite('./data/dataset_valid.csv')\n",
+ "# create_ANA('./data/dataset_valid.csv')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GmLKx_pOSoLy"
+ },
+ "source": [
+ "# Plotting Functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "The code below creates interactive sliders to visualize 3D image volumes and labels. It uses the **peek_class_new** class to handle slider creation and display, and **nibabel** to load and prepare the image and label data for visualization. We use this functionality to visualize the predictions from our trained model."
+ ],
+ "metadata": {
+ "id": "SBC5bVXV9iib"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "id": "Xvw97YKhnFuU"
+ },
+ "outputs": [],
+ "source": [
+ "# Create x, y, and z coordinate sliders\n",
+ "class peek_class_new:\n",
+ " def __init__(self,scan_data, label_bool):\n",
+ " self.label_bool = label_bool\n",
+ " self.scan_data = scan_data\n",
+ " if self.label_bool:\n",
+ " self.x_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[0]-1, value=(self.scan_data.shape[0]-1)/2, description='X')\n",
+ " self.y_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[1]-1, value=(self.scan_data.shape[1]-1)/2, description='Y')\n",
+ " self.z_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[2]-1, value=(self.scan_data.shape[2]-1)/2, description='Z')\n",
+ " else:\n",
+ " self.x_slider = widgets.IntSlider(min=(self.scan_data.shape[0]-1)*-1, max=0, value=((self.scan_data.shape[0]-1)/2)*-1, description='X')\n",
+ " self.y_slider = widgets.IntSlider(min=(self.scan_data.shape[1]-1)*-1, max=0, value=((self.scan_data.shape[1]-1)/2)*-1, description='Y')\n",
+    "            self.z_slider = widgets.IntSlider(min=0, max=self.scan_data.shape[2]-1, value=(self.scan_data.shape[2]-1)/2, description='Z')\n",
+ "\n",
+ " def shape(self):\n",
+ " print(f\" {self.scan_data.shape[0]} {self.scan_data.shape[1]} {self.scan_data.shape[2]}\")\n",
+ "\n",
+ " def update_slices(self, x, y, z):\n",
+ " display_plot = plotting.plot_anat(self.scan_data, cut_coords=(x, y, z)).add_markers(marker_coords=[[x, y, z]])\n",
+ " # Display the plot\n",
+ " plotting.show()\n",
+ "\n",
+ " def plots(self):\n",
+ " # Link the sliders to the update function\n",
+ " widgets.interact(self.update_slices, x=self.x_slider, y=self.y_slider, z=self.z_slider)\n",
+ " # Display the sliders\n",
+ " display(self.x_slider, self.y_slider, self.z_slider)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "id": "8rhmg4VH3yZg"
+ },
+ "outputs": [],
+ "source": [
+ "img= nib.load('data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/t1weighted.nii.gz').get_fdata(dtype=np.float32)\n",
+ "volume_shape = [256, 256, 256]\n",
+ "temp= np.zeros(volume_shape)\n",
+ "temp[: img.shape[0], : img.shape[1], : img.shape[2]] = img\n",
+ "image=temp\n",
+ "# Create a NIfTI image object\n",
+ "nifi_image = nib.Nifti1Image(image, affine=np.eye(4)) # Use identity affine matrix for simplicity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "id": "0oGF8djN4H0W",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 575,
+ "referenced_widgets": [
+ "777ef069c28b46b19dfc96a948400137",
+ "87e89b5837144563bfefec82ef8354de",
+ "b508d6a5cbb64e0793e937b70304a053",
+ "c68c67da6dac4f8f9d47dafa3b983b72",
+ "61fb6b1b34434f5ea077e588ac883f83",
+ "367074e77fcf406688d1b2b6022e30f4",
+ "d087bd6e8ec54991b1894a61c6a4cd6c",
+ "84931741e8604604add005389f0c5370",
+ "653f4b81298c40248cb284378e2b1ba0",
+ "1c49223d5d3f415ebdc16ca30a5cc0b5",
+ "a3a2daae81564f3194451604ff6e5de6",
+ "dd978f745221426c9a87421b4c295ff6",
+ "ec79da9919e94249bdd079f55fd553ed"
+ ]
+ },
+ "outputId": "ec5ec0ab-2472-4fdc-f1d6-e883338f7725"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "interactive(children=(IntSlider(value=127, description='X', max=255), IntSlider(value=127, description='Y', ma…"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "777ef069c28b46b19dfc96a948400137"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='X', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "87e89b5837144563bfefec82ef8354de"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='Y', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "b508d6a5cbb64e0793e937b70304a053"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='Z', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "c68c67da6dac4f8f9d47dafa3b983b72"
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "images = peek_class_new(nifi_image,1)\n",
+ "images.plots()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "id": "MMgLXmvT4gUS"
+ },
+ "outputs": [],
+ "source": [
+ "img= nib.load('data/Mindboggle_101/MMRR-21_volumes/MMRR-21-7/GWlabels.DKT31.manual+aseg.nii.gz').get_fdata(dtype=np.float32)\n",
+ "volume_shape = [256, 256, 256]\n",
+ "temp= np.zeros(volume_shape)\n",
+ "temp[: img.shape[0], : img.shape[1], : img.shape[2]] = img\n",
+ "image=temp\n",
+ "# Create a NIfTI image object\n",
+ "nifi_image = nib.Nifti1Image(image, affine=np.eye(4)) # Use identity affine matrix for simplicity"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "id": "I9Ac24kj4ls2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 575,
+ "referenced_widgets": [
+ "b2fac97873cf4089bbea4ad4b7e89a16",
+ "fbfa6a0ba2f645e5835009972e37cef1",
+ "bc93ecff405b48a3afa11ec5f6a278e4",
+ "b7aa97fa10c3462f953eb5255d47499c",
+ "d372bb6414aa4b84bd3c170d411d167a",
+ "4b1649fe17b243198b85d517737978e4",
+ "a327d96399154967a7a426bf40743f56",
+ "68bf878904674acbb8c3885d915513c4",
+ "9772174186944909a070f85559ca552b",
+ "eeaf6a61f280437da3431ea019049275",
+ "54a89a522ded47b1806d52b0b6c8ece6",
+ "90e6aece08594ddf87be7dcac7805b44",
+ "9b9a4b2079d24f93bde9e57a9e6b3661"
+ ]
+ },
+ "outputId": "413ba204-1f13-4daf-97dc-ef5f39112126"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "interactive(children=(IntSlider(value=127, description='X', max=255), IntSlider(value=127, description='Y', ma…"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "b2fac97873cf4089bbea4ad4b7e89a16"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='X', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "fbfa6a0ba2f645e5835009972e37cef1"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='Y', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "bc93ecff405b48a3afa11ec5f6a278e4"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='Z', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "b7aa97fa10c3462f953eb5255d47499c"
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "labels = peek_class_new(nifi_image,1)\n",
+ "labels.plots()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OvlSnfs38UAS"
+ },
+ "source": [
+    "# MeshNet custom model implementation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "id": "hDa9g4K29nEh"
+ },
+ "outputs": [],
+ "source": [
+ "MeshNet_5_ae16 = [\n",
+ " {\"in_channels\": -1,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 1,\"stride\": 1,\"dilation\": 1,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 2,\"stride\": 1,\"dilation\": 2,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 4,\"stride\": 1,\"dilation\": 4,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 8,\"stride\": 1,\"dilation\": 8,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 16,\"stride\": 1,\"dilation\": 16,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 8,\"stride\": 1,\"dilation\": 8,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 4,\"stride\": 1,\"dilation\": 4,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 2,\"stride\": 1,\"dilation\": 2,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 3,\"out_channels\": 5,\"padding\": 1,\"stride\": 1,\"dilation\": 1,},\n",
+ " {\"in_channels\": 5,\"kernel_size\": 1,\"out_channels\": -1,\"padding\": 0,\"stride\": 1,\"dilation\": 1,},\n",
+ "]"
+ ]
+ },
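+  {
+   "cell_type": "markdown",
+   "source": [
+    "A rough check of this architecture's receptive field: with stride 1, each 3×3×3 convolution with dilation $d$ widens the receptive field by $2d$ voxels per axis, so the dilations 1, 2, 4, 8, 16, 8, 4, 2, 1 give $1 + 2(1+2+4+8+16+8+4+2+1) = 93$ voxels along each axis (the final 1×1×1 layer adds nothing)."
+   ],
+   "metadata": {}
+  },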
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "id": "DawF_iJk-ujc"
+ },
+ "outputs": [],
+ "source": [
+    "import copy\n",
+    "\n",
+    "def ae16channels(channels=5, basearch=MeshNet_5_ae16):\n",
+    "    start = {\"out_channels\": channels}\n",
+    "    middle = {\"in_channels\": channels, \"out_channels\": channels}\n",
+    "    end = {\"in_channels\": channels}\n",
+    "    modifier = [start] + [middle for _ in range(len(basearch) - 2)] + [end]\n",
+    "    # deep copy so updating the new architecture does not mutate the shared template\n",
+    "    newarch = copy.deepcopy(basearch)\n",
+    "    for block, mod in zip(newarch, modifier):\n",
+    "        block.update(mod)\n",
+    "    return newarch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "id": "JkslLfm3-3YI"
+ },
+ "outputs": [],
+ "source": [
+ "def conv_w_bn_before_act(dropout_p=0, bnorm=True, gelu=False, *args, **kwargs):\n",
+ " \"\"\"Configurable Conv block with Batchnorm and Dropout\"\"\"\n",
+ " sequence = [(\"conv\", nn.Conv3d(*args, **kwargs))]\n",
+ " if bnorm:\n",
+ " sequence.append((\"bnorm\", nn.BatchNorm3d(kwargs[\"out_channels\"])))\n",
+ " if gelu:\n",
+ " sequence.append((\"gelu\", nn.GELU()))\n",
+ " else:\n",
+ " sequence.append((\"relu\", nn.ReLU(inplace=True)))\n",
+ " sequence.append((\"dropout\", nn.Dropout3d(dropout_p)))\n",
+ " layer = nn.Sequential(OrderedDict(sequence))\n",
+ " return layer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "id": "lYYzQRaB-6O3"
+ },
+ "outputs": [],
+ "source": [
+ "def init_weights(model):\n",
+ " \"\"\"Set weights to be xavier normal for all Convs\"\"\"\n",
+ " for m in model.modules():\n",
+ " if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):\n",
+ " nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain(\"relu\"))\n",
+ " nn.init.constant_(m.bias, 0.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "id": "u4Y1-3qa-9ii"
+ },
+ "outputs": [],
+ "source": [
+ "class MeshNet(nn.Module):\n",
+ " \"\"\"Configurable MeshNet from https://arxiv.org/pdf/1612.00940.pdf\"\"\"\n",
+ "\n",
+ " def __init__(self, n_channels, n_classes, large=True, bnorm=True, gelu=False, dropout_p=0):\n",
+ " \"\"\"Init\"\"\"\n",
+    "        if large:\n",
+    "            params = ae16channels(5)\n",
+    "        else:\n",
+    "            # copy so the shared template is not mutated by the channel overrides below\n",
+    "            params = copy.deepcopy(MeshNet_5_ae16)\n",
+ "\n",
+ " super(MeshNet, self).__init__()\n",
+ " params[0][\"in_channels\"] = n_channels\n",
+ " params[-1][\"out_channels\"] = n_classes\n",
+ " layers = [\n",
+ " conv_w_bn_before_act(dropout_p=dropout_p, bnorm=bnorm, gelu=gelu, **block_kwargs)\n",
+ " for block_kwargs in params[:-1]\n",
+ " ]\n",
+ " layers.append(nn.Conv3d(**params[-1]))\n",
+ " self.model = nn.Sequential(*layers)\n",
+ " init_weights(self.model)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " \"\"\"Forward pass\"\"\"\n",
+ " x = self.model(x)\n",
+ " return x"
+ ]
+ },
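+  {
+   "cell_type": "markdown",
+   "source": [
+    "The **enMesh_checkpoint** subclass below wraps the training forward pass with **checkpoint_sequential**, which discards intermediate activations and recomputes them during the backward pass, trading extra compute for a much smaller memory footprint on full 256³ volumes."
+   ],
+   "metadata": {}
+  },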
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "id": "T6pC_IrX_B8q"
+ },
+ "outputs": [],
+ "source": [
+ "class enMesh_checkpoint(MeshNet):\n",
+ " def train_forward(self, x):\n",
+ " y = x\n",
+ " y.requires_grad_()\n",
+ " y = checkpoint_sequential(\n",
+ " self.model, len(self.model), y, preserve_rng_state=False\n",
+ " )\n",
+ " return y\n",
+ "\n",
+ " def eval_forward(self, x):\n",
+ " \"\"\"Forward pass\"\"\"\n",
+ " self.model.eval()\n",
+ " with torch.inference_mode():\n",
+ " x = self.model(x)\n",
+ " return x\n",
+ "\n",
+ " def forward(self, x):\n",
+ " if self.training:\n",
+ " return self.train_forward(x)\n",
+ " else:\n",
+ " return self.eval_forward(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UuM8jtHKQOfz"
+ },
+ "source": [
+    "# Subvolume generator, volume reassembler, and dataloader functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "The **CubeDivider** class below provides a convenient way to divide a cube tensor into sub-cubes and reassemble them, which is useful for processing large 3D tensors. For more information, please refer to this blog post: https://medium.com/pytorch/catalyst-neuro-a-3d-brain-segmentation-pipeline-for-mri-b1bb1109276a"
+ ],
+ "metadata": {
+ "id": "aNH5hc_dAVii"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "id": "WesSBrDr3H5x",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "d5bfb55c-2658-432e-8214-043b0750e170"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch.Size([32, 32, 32])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import Dataset\n",
+ "\n",
+ "class CubeDivider:\n",
+ " def __init__(self, tensor, num_cubes):\n",
+ " self.tensor = tensor\n",
+ " self.num_cubes = num_cubes\n",
+ " self.sub_cube_size = tensor.shape[0] // num_cubes # Assuming the tensor is a cube\n",
+ "\n",
+ " def divide_into_sub_cubes(self):\n",
+ " sub_cubes = []\n",
+ "\n",
+ " for i in range(self.num_cubes):\n",
+ " for j in range(self.num_cubes):\n",
+ " for k in range(self.num_cubes):\n",
+ " sub_cube = self.tensor[\n",
+ " i * self.sub_cube_size: (i + 1) * self.sub_cube_size,\n",
+ " j * self.sub_cube_size: (j + 1) * self.sub_cube_size,\n",
+ " k * self.sub_cube_size: (k + 1) * self.sub_cube_size\n",
+ " ].clone()\n",
+ " sub_cubes.append(sub_cube)\n",
+ "\n",
+ " sub_cubes = torch.stack(sub_cubes,0)\n",
+ " return sub_cubes\n",
+ "\n",
+ " @staticmethod\n",
+ " def reassemble_sub_cubes(sub_cubes):\n",
+ " sub_cubes = torch.unbind(sub_cubes, dim=0)\n",
+    "        num_cubes = round(len(sub_cubes) ** (1 / 3))  # cube root; round() guards against float truncation (e.g. int(64 ** (1/3)) can give 3)\n",
+ " sub_cube_size = sub_cubes[0].shape[0]\n",
+ " tensor_size = num_cubes * sub_cube_size\n",
+ " tensor = torch.zeros((tensor_size, tensor_size, tensor_size), dtype=torch.float32)\n",
+ "\n",
+ " for i in range(num_cubes):\n",
+ " for j in range(num_cubes):\n",
+ " for k in range(num_cubes):\n",
+ " sub_cube = sub_cubes[i * num_cubes**2 + j * num_cubes + k]\n",
+ " tensor[\n",
+ " i * sub_cube_size: (i + 1) * sub_cube_size,\n",
+ " j * sub_cube_size: (j + 1) * sub_cube_size,\n",
+ " k * sub_cube_size: (k + 1) * sub_cube_size\n",
+ " ] = sub_cube\n",
+ "\n",
+ " return tensor\n",
+ "\n",
+ "# Usage:\n",
+ "# Assuming tensor is a 3D PyTorch tensor\n",
+ "tensor = torch.randn(32, 32, 32) # Example tensor\n",
+ "num_cubes = 2 # Number of sub-cubes\n",
+ "\n",
+ "divider = CubeDivider(tensor, num_cubes)\n",
+ "\n",
+ "# Divide the cube tensor into sub-cubes\n",
+ "sub_cubes = divider.divide_into_sub_cubes()\n",
+ "\n",
+ "# Reassemble the sub-cubes to create the original cube tensor\n",
+ "reconstructed_tensor = CubeDivider.reassemble_sub_cubes(sub_cubes)\n",
+ "\n",
+ "print(reconstructed_tensor.shape) # Should be the same as the original tensor shape\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "The **DataLoaderClass** class below loads and processes data from a CSV file using PyTorch's DataLoader.\n",
+    "\n",
+    "1. The dataloader method reads the CSV file, preprocesses the images and labels, and creates a DataLoader object for the processed data.\n",
+    "2. The data is divided into sub-cubes using the **CubeDivider** class.\n",
+    "3. The labels are converted into a one-hot encoding (see the sketch below)."
+ ],
+ "metadata": {
+ "id": "Vc7TzLElCXhS"
+ }
+ },
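+  {
+   "cell_type": "markdown",
+   "source": [
+    "A minimal sketch of the one-hot conversion in step 3, on a toy 2-D label map (the loader below builds the same channels-first encoding manually, one class at a time):"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch.nn.functional as F\n",
+    "\n",
+    "lab = torch.tensor([[0, 1], [2, 1]])     # toy label map with classes 0..2\n",
+    "onehot = F.one_hot(lab, num_classes=3)   # shape (2, 2, 3), one channel per class\n",
+    "print(onehot.permute(2, 0, 1).float())   # channels-first (3, 2, 2), as the loader produces"
+   ]
+  },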
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "id": "cT9N0M3q1ymK"
+ },
+ "outputs": [],
+ "source": [
+ "class DataLoaderClass:\n",
+ " def __init__(self,csv_file, coor_factor, batch_size):\n",
+ " self.csv_file=csv_file\n",
+ " self.coor_factor=coor_factor\n",
+ " self.batch_size=batch_size\n",
+ "\n",
+ " def dataloader(self):\n",
+ " data = pd.read_csv(self.csv_file)\n",
+ " volume_shape = [256, 256, 256]\n",
+ " images =()\n",
+ " labels=()\n",
+ " for image,label in zip(data['images'],data['GWlabels']):\n",
+ "\n",
+ " img = nib.load('./'+image)\n",
+ " img = img.get_fdata()\n",
+ " temp= np.zeros(volume_shape)\n",
+ " temp[: img.shape[0], : img.shape[1], : img.shape[2]] = img\n",
+ " temp = np.array(temp)\n",
+ " image_data = (temp - temp.mean()) / temp.std()\n",
+ " sub_temp = CubeDivider(torch.tensor(image_data),self.coor_factor)\n",
+ " images = images+(sub_temp.divide_into_sub_cubes(),)\n",
+ "\n",
+ " lab = nib.load('./'+label)\n",
+ " lab = lab.get_fdata()\n",
+ " temp= np.zeros(volume_shape)\n",
+ " temp[: lab.shape[0], : lab.shape[1], : lab.shape[2]] = lab\n",
+ " temp = np.array(temp)\n",
+ " sub_temp = CubeDivider(torch.tensor(temp),self.coor_factor)\n",
+ " labels = labels+(sub_temp.divide_into_sub_cubes(),)\n",
+ "\n",
+ " images = torch.stack(images)\n",
+ " labels = torch.stack(labels)\n",
+ " images = images.reshape(-1,1,int(volume_shape[0]/self.coor_factor),int(volume_shape[1]/self.coor_factor),int(volume_shape[2]/self.coor_factor)).float()\n",
+ " labels = labels.reshape(-1,1,int(volume_shape[0]/self.coor_factor),int(volume_shape[1]/self.coor_factor),int(volume_shape[2]/self.coor_factor))\n",
+ " new_labels = ()\n",
+ " for temp in labels:\n",
+ " new_temp = ()\n",
+ " for i in [0,1,2]:\n",
+ " new_temp=new_temp+ (torch.mul(torch.tensor(np.asarray(temp == i, dtype=np.float64)),1),)\n",
+ " new_temp = torch.stack(new_temp)\n",
+ " new_labels = new_labels + (new_temp,)\n",
+ " labels = torch.stack(new_labels)\n",
+ " labels = labels.reshape(-1,3,int(volume_shape[0]/self.coor_factor),int(volume_shape[1]/self.coor_factor),int(volume_shape[2]/self.coor_factor))\n",
+ " dataset = torch.utils.data.TensorDataset(images, labels)\n",
+ " return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cuUNlPH_qWeG"
+ },
+ "source": [
+    "# PyTorch training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "A common metric for assessing the similarity between two sets or segmentations is the Dice score. Ranging from 0 to 1, it measures the overlap between the predicted and ground-truth labels: a score of 1 represents exact overlap, while 0 represents no overlap. The fudge factor avoids division by zero."
+ ],
+ "metadata": {
+ "id": "n1MB5aTwE26e"
+ }
+ },
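+  {
+   "cell_type": "markdown",
+   "source": [
+    "In symbols, for a label $l$ with predicted mask $X_l$ and ground-truth mask $Y_l$:\n",
+    "\n",
+    "$$\\mathrm{Dice}(X_l, Y_l) = \\frac{2\\,|X_l \\cap Y_l|}{|X_l| + |Y_l| + \\epsilon},$$\n",
+    "\n",
+    "where $\\epsilon$ is the fudge factor that keeps the denominator non-zero."
+   ],
+   "metadata": {}
+  },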
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "id": "xu4EeuURrNlt"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "\n",
+ "def faster_dice(x, y, labels, fudge_factor=1e-8):\n",
+ " \"\"\"Faster PyTorch implementation of Dice scores.\n",
+ " :param x: input label map as torch.Tensor\n",
+ " :param y: input label map as torch.Tensor of the same size as x\n",
+ " :param labels: list of labels to evaluate on\n",
+ " :param fudge_factor: an epsilon value to avoid division by zero\n",
+ " :return: pytorch Tensor with Dice scores in the same order as labels.\n",
+ " \"\"\"\n",
+ "\n",
+ " assert x.shape == y.shape, \"both inputs should have same size, had {} and {}\".format(\n",
+ " x.shape, y.shape\n",
+ " )\n",
+ "\n",
+ " if len(labels) > 1:\n",
+ "\n",
+ " dice_score = torch.zeros(len(labels))\n",
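+    "        # note: dice_score is indexed by label value, so labels must be the consecutive integers 0..len(labels)-1\n",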
+ " for label in labels:\n",
+ " x_label = x == label\n",
+ " y_label = y == label\n",
+ " xy_label = (x_label & y_label).sum()\n",
+ " dice_score[label] = (\n",
+ " 2 * xy_label / (x_label.sum() + y_label.sum() + fudge_factor)\n",
+ " )\n",
+ "\n",
+ " else:\n",
+ " dice_score = dice(x == labels[0], y == labels[0], fudge_factor=fudge_factor)\n",
+ "\n",
+ " return dice_score\n",
+ "\n",
+ "\n",
+ "def dice(x, y, fudge_factor=1e-8):\n",
+    "    \"\"\"Dice score for 0/1 torch tensors\"\"\"\n",
+ " return 2 * torch.sum(x * y) / (torch.sum(x) + torch.sum(y) + fudge_factor)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "* The trainer class is responsible for training a neural network model for image segmentation.\n",
+    "* It takes parameters such as the number of input channels, the number of output classes, data loaders for training and validation, the subvolume shape, the number of epochs, the path to a model checkpoint file, and the learning rate.\n",
+    "* The constructor initializes the model, the criterion (**CrossEntropyLoss**), the optimizer (**RMSprop**), and other class variables.\n",
+    "* The train method trains the model for the specified number of epochs; within each epoch it iterates over the training data, computes the loss and Dice scores, performs backpropagation, and updates the model's parameters.\n",
+    "* After each training loop, it evaluates the model on the validation data, computes the loss and Dice scores, and prints the training and validation metrics for the epoch.\n",
+    "* The **faster_dice** function is used to calculate the Dice scores."
+ ],
+ "metadata": {
+ "id": "XsQs-uPFFO9j"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "id": "G7I-kcMMqWL4"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.nn import functional as F\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "class trainer:\n",
+ " def __init__(self,n_channels, n_classes, trainloader, valloader, subvol_shape, epoches,modelpth,lrate=0.0007):\n",
+ " self.n_channels = n_channels # Number of input channels\n",
+ " self.n_classes = n_classes # Number of output classes\n",
+ " self.model = enMesh_checkpoint(self.n_channels, self.n_classes).to(device, dtype=torch.float32)\n",
+ " self.criterion = nn.CrossEntropyLoss()\n",
+ " self.lrate = lrate\n",
+ " self.trainloader = trainloader\n",
+ " self.valloader = valloader\n",
+ " self.subvol_shape = subvol_shape\n",
+ " self.epoches = epoches\n",
+ " self.modelpth = modelpth\n",
+ " self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=self.lrate)\n",
+ "\n",
+ "\n",
+ " def train(self, num_epoches):\n",
+ " try:\n",
+ " self.model.load_state_dict(torch.load(self.modelpth))\n",
+    "        except Exception:\n",
+    "            print('No valid pretrained model.pth file mentioned')\n",
+ " epoch =0\n",
+ " train_loss = 0.0\n",
+ " train_dice = 0.0\n",
+ " val_loss = 0.0\n",
+ " val_dice = 0.0\n",
+ " while epoch != num_epoches :\n",
+ "\n",
+ " self.model.train()\n",
+    "            train_loss = 0.0\n",
+    "            train_dice = 0.0\n",
+ " for images, labels in self.trainloader:\n",
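+    "                # skip batches whose labels are background-only (no gray- or white-matter voxels)\n",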
+ " if 1 in torch.argmax(torch.squeeze(labels),0) or 2 in torch.argmax(torch.squeeze(labels),0):\n",
+ " images = images.to(device, dtype=torch.float32)\n",
+ " labels = labels.to(device, dtype=torch.float32)\n",
+ " self.optimizer.zero_grad()\n",
+ " outputs = self.model(images)\n",
+ " loss=self.criterion(outputs, labels)\n",
+ " train_loss += loss.item()\n",
+ " dice_scores = faster_dice(torch.argmax(torch.squeeze(outputs),0), torch.argmax(torch.squeeze(labels),0), labels=[0, 1, 2]) # Specify the labels to evaluate on\n",
+ " train_dice += dice_scores.mean().item() # Take the mean Dice score\n",
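+    "                    # note: .item() detaches the Dice term, so it adjusts the reported loss value but contributes no gradient\n",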
+    "                    loss = loss + (1 - dice_scores.mean().item())\n",
+ " loss.backward()\n",
+ " self.optimizer.step()\n",
+ "\n",
+ " self.model.eval()\n",
+ " val_loss = 0.0\n",
+ " val_dice = 0.0\n",
+ " with torch.no_grad():\n",
+ " for images, labels in self.valloader:\n",
+ " images = images.to(device, dtype=torch.float32)\n",
+ " labels = labels.to(device, dtype=torch.float32)\n",
+ " outputs = self.model(images)\n",
+ " loss = self.criterion(outputs, labels)\n",
+ " val_loss += loss.item()\n",
+ " dice_scores = faster_dice(torch.argmax(torch.squeeze(outputs),0), torch.argmax(torch.squeeze(labels),0), labels=[0, 1, 2])\n",
+ " val_dice += dice_scores.mean().item()\n",
+ "\n",
+ "\n",
+ " train_loss /= len(self.trainloader)\n",
+ " train_dice /= len(self.trainloader)\n",
+ " val_loss /= len(self.valloader)\n",
+ " val_dice /= len(self.valloader)\n",
+ "\n",
+ " print(f\"Epoch {epoch+1} - Train Loss: {train_loss:.4f} - Train Dice: {train_dice:.4f} - Val Loss: {val_loss:.4f} - Val Dice: {val_dice:.4f}\")\n",
+ " epoch = epoch+1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "# Training\n"
+ ],
+ "metadata": {
+ "id": "XdahCxTd2ik3"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+    "**Cumulative learning** - training runs in 4 separate PyTorch cycles, varying how each image is divided into smaller subvolumes."
+ ],
+ "metadata": {
+ "id": "rHl965bfShJo"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "## Training Cycle 1: subvolume (256,256,256)"
+ ],
+ "metadata": {
+ "id": "83o2ydQ690xt"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "id": "xZbLNvZFHCxh",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "8ada23d4-21dd-4a8d-be62-f45f2d6af027"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+      "No valid pretrained model.pth file mentioned\n",
+ "Epoch 1 - Train Loss: 0.9851 - Train Dice: 0.0381 - Val Loss: 0.7933 - Val Dice: 0.3471\n",
+ "Epoch 2 - Train Loss: 0.6630 - Train Dice: 0.0321 - Val Loss: 0.6158 - Val Dice: 0.3251\n",
+ "Epoch 3 - Train Loss: 0.5784 - Train Dice: 0.0321 - Val Loss: 0.5654 - Val Dice: 0.3248\n",
+ "Epoch 4 - Train Loss: 0.5324 - Train Dice: 0.0320 - Val Loss: 0.5203 - Val Dice: 0.3248\n",
+ "Epoch 5 - Train Loss: 0.5012 - Train Dice: 0.0321 - Val Loss: 0.4618 - Val Dice: 0.3248\n",
+ "Epoch 6 - Train Loss: 0.4746 - Train Dice: 0.0319 - Val Loss: 0.4433 - Val Dice: 0.3248\n",
+ "Epoch 7 - Train Loss: 0.4529 - Train Dice: 0.0322 - Val Loss: 0.4214 - Val Dice: 0.3248\n",
+ "Epoch 8 - Train Loss: 0.4355 - Train Dice: 0.0320 - Val Loss: 0.4150 - Val Dice: 0.3248\n",
+ "Epoch 9 - Train Loss: 0.4193 - Train Dice: 0.0322 - Val Loss: 0.3806 - Val Dice: 0.3248\n",
+ "Epoch 10 - Train Loss: 0.4045 - Train Dice: 0.0322 - Val Loss: 0.3737 - Val Dice: 0.3248\n",
+ "Epoch 11 - Train Loss: 0.3914 - Train Dice: 0.0322 - Val Loss: 0.3493 - Val Dice: 0.3248\n",
+ "Epoch 12 - Train Loss: 0.3794 - Train Dice: 0.0322 - Val Loss: 0.3471 - Val Dice: 0.3248\n",
+ "Epoch 13 - Train Loss: 0.3672 - Train Dice: 0.0321 - Val Loss: 0.3299 - Val Dice: 0.3248\n",
+ "Epoch 14 - Train Loss: 0.3561 - Train Dice: 0.0322 - Val Loss: 0.3117 - Val Dice: 0.3248\n",
+ "Epoch 15 - Train Loss: 0.3450 - Train Dice: 0.0323 - Val Loss: 0.3061 - Val Dice: 0.3248\n",
+ "Epoch 16 - Train Loss: 0.3346 - Train Dice: 0.0323 - Val Loss: 0.2999 - Val Dice: 0.3248\n",
+ "Epoch 17 - Train Loss: 0.3213 - Train Dice: 0.0322 - Val Loss: 0.2985 - Val Dice: 0.3248\n",
+ "Epoch 18 - Train Loss: 0.3111 - Train Dice: 0.0320 - Val Loss: 0.2897 - Val Dice: 0.3248\n",
+ "Epoch 19 - Train Loss: 0.2991 - Train Dice: 0.0323 - Val Loss: 0.2747 - Val Dice: 0.3248\n",
+ "Epoch 20 - Train Loss: 0.2894 - Train Dice: 0.0322 - Val Loss: 0.2586 - Val Dice: 0.3248\n"
+ ]
+ }
+ ],
+ "source": [
+ "traindata = DataLoaderClass('./data/dataset_train.csv',1,1).dataloader()\n",
+ "valdata = DataLoaderClass('./data/dataset_valid.csv',1,1).dataloader()\n",
+ "meshnet = trainer(1,3,traindata, valdata, [256,256,256], 20,'',0.0007)\n",
+ "meshnet.train(20)\n",
+ "torch.save(meshnet.model.state_dict(), 'meshnet.pth')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
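+    "# reset references to large tensors/objects so gc.collect() and empty_cache() can reclaim memory\n",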
+ "traindata=''\n",
+ "array=''\n",
+ "i=''\n",
+ "array1=''\n",
+ "images =''\n",
+ "img=''\n",
+ "prediciton=''\n",
+ "predicted=''\n",
+ "temp=''\n",
+ "labels=''\n",
+ "pred_peek=''\n",
+ "volume_shape=''\n",
+ "num_cubes=''\n",
+ "nifi_image=''\n",
+ "criterion=''\n",
+ "valdata=''\n",
+ "loaders=''\n",
+ "divider=''\n",
+ "sub_cubes=''\n",
+ "meshnet=''\n",
+ "model = ''\n",
+ "logdir=''\n",
+ "optimizer=''\n",
+ "scheduler=''\n",
+ "runner=''\n",
+ "tensor=''\n",
+ "reconstructed_tensor=''\n",
+ "import gc\n",
+ "torch.cuda.empty_cache()\n",
+ "gc.collect()"
+ ],
+ "metadata": {
+ "id": "-UPJryQI-Oj_",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "3684e958-cd42-4077-b458-97d3a0e60ab0"
+ },
+ "execution_count": 29,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 29
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "## Training Cycle 2: subvolume (32,32,32)"
+ ],
+ "metadata": {
+ "id": "O9Fu99XO993v"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "traindata = DataLoaderClass('./data/dataset_train.csv',8,1).dataloader()\n",
+ "valdata = DataLoaderClass('./data/dataset_valid.csv',8,1).dataloader()\n",
+ "meshnet = trainer(1,3,traindata, valdata, [32,32,32], 20,'meshnet.pth',0.0007)\n",
+ "meshnet.train(20)\n",
+ "torch.save(meshnet.model.state_dict(), 'meshnet.pth')"
+ ],
+ "metadata": {
+ "id": "ZY4-DYdxtQo7"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "traindata=''\n",
+ "array=''\n",
+ "i=''\n",
+ "array1=''\n",
+ "images =''\n",
+ "img=''\n",
+ "prediciton=''\n",
+ "predicted=''\n",
+ "temp=''\n",
+ "labels=''\n",
+ "pred_peek=''\n",
+ "volume_shape=''\n",
+ "num_cubes=''\n",
+ "nifi_image=''\n",
+ "criterion=''\n",
+ "valdata=''\n",
+ "loaders=''\n",
+ "divider=''\n",
+ "sub_cubes=''\n",
+ "meshnet=''\n",
+ "model = ''\n",
+ "logdir=''\n",
+ "optimizer=''\n",
+ "scheduler=''\n",
+ "runner=''\n",
+ "tensor=''\n",
+ "reconstructed_tensor=''\n",
+ "import gc\n",
+ "torch.cuda.empty_cache()\n",
+ "gc.collect()"
+ ],
+ "metadata": {
+ "id": "buhkzd-1-Nkc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "## Training Cycle 3: subvolume (64,64,64)"
+ ],
+ "metadata": {
+ "id": "helYON-I-Cn6"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "traindata = DataLoaderClass('./data/dataset_train.csv',4,1).dataloader()\n",
+ "valdata = DataLoaderClass('./data/dataset_valid.csv',4,1).dataloader()\n",
+ "meshnet = trainer(1,3,traindata, valdata, [64,64,64], 20,'meshnet.pth',0.0007)\n",
+ "meshnet.train(20)\n",
+ "torch.save(meshnet.model.state_dict(), 'meshnet.pth')"
+ ],
+ "metadata": {
+ "id": "om3Oq3vwxAus"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "traindata=''\n",
+ "array=''\n",
+ "i=''\n",
+ "array1=''\n",
+ "images =''\n",
+ "img=''\n",
+ "prediciton=''\n",
+ "predicted=''\n",
+ "temp=''\n",
+ "labels=''\n",
+ "pred_peek=''\n",
+ "volume_shape=''\n",
+ "num_cubes=''\n",
+ "nifi_image=''\n",
+ "criterion=''\n",
+ "valdata=''\n",
+ "loaders=''\n",
+ "divider=''\n",
+ "sub_cubes=''\n",
+ "meshnet=''\n",
+ "model = ''\n",
+ "logdir=''\n",
+ "optimizer=''\n",
+ "scheduler=''\n",
+ "runner=''\n",
+ "tensor=''\n",
+ "reconstructed_tensor=''\n",
+ "import gc\n",
+ "torch.cuda.empty_cache()\n",
+ "gc.collect()"
+ ],
+ "metadata": {
+ "id": "f9arABKL-MjE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "## Training Cycle 4: subvolume (128,128,128)"
+ ],
+ "metadata": {
+ "id": "XHmbvki0-HmK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "traindata = DataLoaderClass('./data/dataset_train.csv',2,1).dataloader()\n",
+ "valdata = DataLoaderClass('./data/dataset_valid.csv',2,1).dataloader()\n",
+ "meshnet = trainer(1,3,traindata, valdata, [128,128,128], 20,'meshnet.pth',0.0007)\n",
+ "meshnet.train(20)\n",
+ "torch.save(meshnet.model.state_dict(), 'meshnet.pth')"
+ ],
+ "metadata": {
+ "id": "u0BrzwfK5qXA"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "traindata=''\n",
+ "array=''\n",
+ "i=''\n",
+ "array1=''\n",
+ "images =''\n",
+ "img=''\n",
+ "prediciton=''\n",
+ "predicted=''\n",
+ "temp=''\n",
+ "labels=''\n",
+ "pred_peek=''\n",
+ "volume_shape=''\n",
+ "num_cubes=''\n",
+ "nifi_image=''\n",
+ "criterion=''\n",
+ "valdata=''\n",
+ "loaders=''\n",
+ "divider=''\n",
+ "sub_cubes=''\n",
+ "meshnet=''\n",
+ "model = ''\n",
+ "logdir=''\n",
+ "optimizer=''\n",
+ "scheduler=''\n",
+ "runner=''\n",
+ "tensor=''\n",
+ "reconstructed_tensor=''\n",
+ "import gc\n",
+ "torch.cuda.empty_cache()\n",
+ "gc.collect()"
+ ],
+ "metadata": {
+ "id": "M6rVC4HKgkMb"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "# Model evaluation on the inference dataset"
+ ],
+ "metadata": {
+ "id": "kAxwlh3M8ULH"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from torch.nn import functional as F\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "class evaluation:\n",
+ " def __init__(self, modelpath, inferloader):\n",
+ " self.inferloader = inferloader\n",
+ " self.modelpath = modelpath\n",
+ " self.model = enMesh_checkpoint(1, 3).to(device, dtype=torch.float32)\n",
+ " self.criterion = nn.CrossEntropyLoss()\n",
+ "\n",
+ " def eval(self):\n",
+ " try:\n",
+ " self.model.load_state_dict(torch.load(self.modelpath))\n",
+ " self.model.eval()\n",
+ " infer_loss = 0.0\n",
+ " infer_dice = 0.0\n",
+ " with torch.no_grad():\n",
+ " for images, labels in self.inferloader:\n",
+ " images = images.to(device, dtype=torch.float32)\n",
+ " labels = labels.to(device, dtype=torch.float32)\n",
+ " outputs = self.model(images)\n",
+ " loss = self.criterion(outputs, labels)\n",
+ " infer_loss += loss.item()\n",
+ " dice_scores = faster_dice(torch.argmax(torch.squeeze(outputs),0), torch.argmax(torch.squeeze(labels),0), labels=[0, 1, 2])\n",
+ " infer_dice += dice_scores.mean().item()\n",
+ " infer_loss /= len(self.inferloader)\n",
+ " infer_dice /= len(self.inferloader)\n",
+ " print('Loss :',infer_loss,' Dice :',infer_dice)\n",
+ " except Exception as e:\n",
+    "            print('No valid pretrained model.pth file mentioned', e)\n"
+ ],
+ "metadata": {
+ "id": "N5WC0gqm3mhm"
+ },
+ "execution_count": 30,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "inferdata = DataLoaderClass('./data/dataset_infer.csv',1,1).dataloader()\n",
+ "modeval=evaluation('meshnet.pth',inferdata)\n",
+ "modeval.eval()"
+ ],
+ "metadata": {
+ "id": "sTVsT5gQ8Gpu",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "d32ec6cb-da65-41e1-860d-547b029dde83"
+ },
+ "execution_count": 31,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Loss : 0.2616010407606761 Dice : 0.3222907781600952\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "224UKySu240T"
+ },
+ "source": [
+    "# Plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "id": "F5eVPaTZ27iQ"
+ },
+ "outputs": [],
+ "source": [
+ "inferdata = DataLoaderClass('./data/dataset_infer.csv',1,1).dataloader()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "Ground-truth label"
+ ],
+ "metadata": {
+ "id": "dd8nh9plNrK_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "id": "-fLzOl4r2-Qe",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 575,
+ "referenced_widgets": [
+ "e610567fb36246cd8a82ca6548853632",
+ "f7ae22f555734bc88e1fe06c0f853592",
+ "bb003207aea44fddbe696ee9b24472bc",
+ "253608ba694444e99ca331b7472d2533",
+ "26ea0fed3bd34f4d96024b39812c5d70",
+ "593dfac8552c47ed83ac9e626eac3f62",
+ "1681e259474a40ecbebfad76fa255d3a",
+ "ad5469fed7464799a1d1b84bcedcf7b4",
+ "5c42684be0e14b0e82e17b4fda8e6da2",
+ "c9615b06b1dd4209931c966ac5a869b6",
+ "47e8f90022a2486e88c9991fb731e6aa",
+ "20aa513df93b4e3896d710fd3982bdad",
+ "508735fc63844aaf8e5c0722ae38d8f1"
+ ]
+ },
+ "outputId": "6dcd04da-dd31-4c0a-df64-e459cd4e1f4e"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "interactive(children=(IntSlider(value=127, description='X', max=255), IntSlider(value=127, description='Y', ma…"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "e610567fb36246cd8a82ca6548853632"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='X', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "f7ae22f555734bc88e1fe06c0f853592"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='Y', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "bb003207aea44fddbe696ee9b24472bc"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "IntSlider(value=127, description='Z', max=255)"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "253608ba694444e99ca331b7472d2533"
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "prediciton = inferdata.dataset.tensors[1][0].reshape(-1,3,256,256,256)\n",
+ "predicted = torch.argmax(torch.squeeze(prediciton),0)\n",
+ "prediciton = predicted.reshape(256,256,256).numpy()\n",
+ "array1 = prediciton.astype(np.uint16)\n",
+ "nifi_image = nib.Nifti1Image(array1, affine=np.eye(4)) # Use identity affine matrix for simplicity\n",
+ "pred_peek= peek_class_new(nifi_image,1)\n",
+ "pred_peek.plots()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+    "Predicted label"
+ ],
+ "metadata": {
+ "id": "4vLKF3FQNvUV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "id": "fW0dHliJ3Dn_",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "f796dc03-86b9-4b94-95b7-890156fb1fc9"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "enMesh_checkpoint(\n",
+ " (model): Sequential(\n",
+ " (0): Sequential(\n",
+ " (conv): Conv3d(1, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (1): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (2): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (3): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (4): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(16, 16, 16), dilation=(16, 16, 16))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (5): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(8, 8, 8), dilation=(8, 8, 8))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (6): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(4, 4, 4), dilation=(4, 4, 4))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (7): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(2, 2, 2), dilation=(2, 2, 2))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (8): Sequential(\n",
+ " (conv): Conv3d(5, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n",
+ " (bnorm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (relu): ReLU(inplace=True)\n",
+ " (dropout): Dropout3d(p=0, inplace=False)\n",
+ " )\n",
+ " (9): Conv3d(5, 3, kernel_size=(1, 1, 1), stride=(1, 1, 1))\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 34
+ }
+ ],
+ "source": [
+ "model = enMesh_checkpoint(n_channels=1, n_classes=3)\n",
+ "model.load_state_dict(\n",
+ " torch.load(\"meshnet.pth\")\n",
+ ")\n",
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "prediciton = model(inferdata.dataset.tensors[0][0].reshape(-1,1,256,256,256))\n",
+ "predicted = torch.argmax(torch.squeeze(prediciton),0)\n",
+ "prediciton = predicted.reshape(256,256,256).numpy()\n",
+ "array = prediciton.astype(np.uint16)\n",
+ "nifi_image = nib.Nifti1Image(array, affine=np.eye(4)) # Use identity affine matrix for simplicity\n",
+ "pred_peek= peek_class_new(nifi_image,1)\n",
+ "pred_peek.plots()"
+ ],
+ "metadata": {
+ "id": "tSOoQD6EkR-U"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
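+ {
+ "cell_type": "markdown",
+ "source": [
+ "As a quick sanity check (not part of the original pipeline), the sketch below computes the per-class Dice score between the predicted label map and the ground truth, assuming the `prediction` and `ground_truth` tensors from the cells above are still in scope:"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Hedged sketch: per-class Dice between `prediction` and `ground_truth`\n",
+ "# (both are 256x256x256 integer label volumes from the cells above)\n",
+ "for c in range(3):\n",
+ "    pred_c = (prediction == c)\n",
+ "    gt_c = (ground_truth == c)\n",
+ "    intersection = (pred_c & gt_c).sum().item()\n",
+ "    denom = pred_c.sum().item() + gt_c.sum().item()\n",
+ "    dice = 2.0 * intersection / denom if denom > 0 else float('nan')\n",
+ "    print(f\"Class {c}: Dice = {dice:.4f}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },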
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Final notes**"
+ ],
+ "metadata": {
+ "id": "Q2JEHsWhZvgk"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This tutorial aims to provide a simple example of how to train the MeshNet model. However, it is worth noting that the actual brainchop models used in the tool have high accuracy. Therefore, thousands of MRI scans are needed during the training phase to achieve this level of accuracy. Unfortunately, the current capacity of Google Colab is insufficient to handle such a large dataset, and a cluster is required for the training process."
+ ],
+ "metadata": {
+ "id": "q_yFRL5BZ22b"
+ }
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "777ef069c28b46b19dfc96a948400137": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "VBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [
+ "widget-interact"
+ ],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "VBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "VBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_87e89b5837144563bfefec82ef8354de",
+ "IPY_MODEL_b508d6a5cbb64e0793e937b70304a053",
+ "IPY_MODEL_c68c67da6dac4f8f9d47dafa3b983b72",
+ "IPY_MODEL_61fb6b1b34434f5ea077e588ac883f83"
+ ],
+ "layout": "IPY_MODEL_367074e77fcf406688d1b2b6022e30f4"
+ }
+ },
+ "87e89b5837144563bfefec82ef8354de": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "IntSliderModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "IntSliderModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "IntSliderView",
+ "continuous_update": true,
+ "description": "X",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_d087bd6e8ec54991b1894a61c6a4cd6c",
+ "max": 255,
+ "min": 0,
+ "orientation": "horizontal",
+ "readout": true,
+ "readout_format": "d",
+ "step": 1,
+ "style": "IPY_MODEL_84931741e8604604add005389f0c5370",
+ "value": 127
+ }
+ },
+ "b508d6a5cbb64e0793e937b70304a053": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "IntSliderModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "IntSliderModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "IntSliderView",
+ "continuous_update": true,
+ "description": "Y",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_653f4b81298c40248cb284378e2b1ba0",
+ "max": 255,
+ "min": 0,
+ "orientation": "horizontal",
+ "readout": true,
+ "readout_format": "d",
+ "step": 1,
+ "style": "IPY_MODEL_1c49223d5d3f415ebdc16ca30a5cc0b5",
+ "value": 127
+ }
+ },
+ "c68c67da6dac4f8f9d47dafa3b983b72": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "IntSliderModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "IntSliderModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "IntSliderView",
+ "continuous_update": true,
+ "description": "Z",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_a3a2daae81564f3194451604ff6e5de6",
+ "max": 255,
+ "min": 0,
+ "orientation": "horizontal",
+ "readout": true,
+ "readout_format": "d",
+ "step": 1,
+ "style": "IPY_MODEL_dd978f745221426c9a87421b4c295ff6",
+ "value": 127
+ }
+ },
+ "61fb6b1b34434f5ea077e588ac883f83": {
+ "model_module": "@jupyter-widgets/output",
+ "model_name": "OutputModel",
+ "model_module_version": "1.0.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/output",
+ "_model_module_version": "1.0.0",
+ "_model_name": "OutputModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/output",
+ "_view_module_version": "1.0.0",
+ "_view_name": "OutputView",
+ "layout": "IPY_MODEL_ec79da9919e94249bdd079f55fd553ed",
+ "msg_id": "",
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "