Machine Learning for Medical Image Processing. Project done for Charité, the university hospital of Berlin. The aim was to segment coronary arteries and then extract a graph from it, in order to aid detection of coronary artery disease (CAD).

emile-gaudinot/ML4MIP

 
 

ML4MIP: Machine Learning for Medical Image Processing

Overview

ML4MIP uses deep learning models such as UNETR, UNet, MedSAM, and nnU-Net to segment coronary arteries in computed tomography angiography (CTA) images and then extracts graph-based representations of the coronary structures. The goal is to support early detection of coronary artery disease (CAD), the most common cause of mortality in developed countries, as well as treatment planning for patients. CAD is characterized by the narrowing of coronary vessels due to plaque buildup, which reduces blood flow and can lead to serious health events.

ML4MIP is a machine learning framework that covers training, validation, inference, preprocessing, postprocessing, and graph extraction. It uses the Hydra configuration manager for flexible configuration handling.

Features

  • Training: Train deep learning models on medical imaging datasets.
  • Validation: Evaluate trained models using validation datasets.
  • Preprocessing: Prepare datasets for training and inference.
  • Inference: Run inference on new medical images.
  • Graph Extraction: Extract graph representations from medical images.
  • Postprocessing: Apply filtering and cleanup operations on model outputs.
  • MLFlow Integration: Log experiments, model parameters, and metrics.
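
To make the Graph Extraction feature more concrete, here is a minimal sketch of one way a graph can be read off a segmented vessel tree. It assumes the segmentation has already been thinned to a one-voxel-wide centreline, and it uses 26-connectivity to link neighbouring voxels. The function names `skeleton_to_graph` and `classify_nodes` and the toy data are hypothetical and are not taken from the ML4MIP code, whose actual method may differ.

```python
from itertools import product

# All 26 neighbour offsets of a voxel in a 3D grid (assumed connectivity).
OFFSETS = [d for d in product((-1, 0, 1), repeat=3) if d != (0, 0, 0)]


def skeleton_to_graph(voxels):
    """Map each centreline voxel to the set of centreline voxels touching it."""
    voxels = set(voxels)
    return {
        v: {(v[0] + dx, v[1] + dy, v[2] + dz) for dx, dy, dz in OFFSETS} & voxels
        for v in voxels
    }


def classify_nodes(graph):
    """Return (endpoints, junctions): voxels with one neighbour, or three or more."""
    endpoints = sorted(v for v, nbrs in graph.items() if len(nbrs) == 1)
    junctions = sorted(v for v, nbrs in graph.items() if len(nbrs) >= 3)
    return endpoints, junctions


# Hypothetical toy centreline: a trunk along x that splits into two branches.
toy_centreline = [
    (0, 0, 0), (1, 0, 0), (2, 0, 0),   # trunk
    (3, 1, 0), (4, 2, 0),              # upper branch
    (3, -1, 0), (4, -2, 0),            # lower branch
]

graph = skeleton_to_graph(toy_centreline)
endpoints, junctions = classify_nodes(graph)
print("endpoints:", endpoints)
print("junctions:", junctions)
```

On real skeletons, several adjacent voxels around a bifurcation can each have three or more neighbours, so a practical pipeline would usually merge such clusters into a single junction node before building vessel segments.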

Installation

ML4MIP uses pyproject.toml for dependency management. To install the project and its dependencies, follow these steps:

# Clone the repository
git clone git@github.com:pvmeng/ML4MIP.git
cd ML4MIP

# Create and activate a virtual environment (optional but recommended)
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`

# Install the package
pip install .

Alternatively, if you want to install it in editable mode for development:

pip install -e .

Usage

ML4MIP provides a set of command-line scripts for various tasks. To list the parameters of a command together with their default values, pass --help. For the train command, this prints the following configuration:

train --help

ml_flow_uri: file://${hydra:runtime.cwd}/runs
model_dir: ${hydra:runtime.cwd}/models
model_tag: unet.pt
batch_size: 16
lr: 0.01
num_epochs: 100
model:
  model_type: UNETMONAI2
  model_path: null
  base_model_jit_path: null
  checkpoint_path: null
dataset:
  train:
    data_dir: ${hydra:runtime.cwd}/data/rand_patch_96
    mask_dir: ${hydra:runtime.cwd}/data/rand_patch_96
    image_affix:
    - ''
    - .img.nii.gz
    mask_affix:
    - ''
    - .label.nii.gz
    transform: PATCH_POS_CENTER
    size:
    - 96
    - 96
    - 96
    train: true
    split_ratio: 0.9
    target_pixel_dim:
    - 0.35
    - 0.35
    - 0.5
    target_spatial_size:
    - 600
    - 600
    - 280
    sigma_ratio: 0.1
    pos_center_prob: 0.75
    max_samples: null
    cache: false
    cache_pooling: 0
    mask_operation: STD
    max_epochs: 40
    grouped: true
  val:
    data_dir: /data/training_data
    mask_dir: /data/training_data
    image_affix:
    - ''
    - .img.nii.gz
    mask_affix:
    - ''
    - .label.nii.gz
    transform: STD
    size:
    - 96
    - 96
    - 96
    train: true
    split_ratio: 0.9
    target_pixel_dim:
    - 0.35
    - 0.35
    - 0.5
    target_spatial_size:
    - 600
    - 600
    - 280
    sigma_ratio: 0.1
    pos_center_prob: 0.75
    max_samples: 2
    cache: false
    cache_pooling: 0
    mask_operation: STD
    max_epochs: 1
    grouped: false
visualize_model: true
visualize_model_val_batches: 1
visualize_model_train_batches: 4
plot_3d: true
extract_graph: false
epoch_profiling_torch: false
epoch_profiling_cpy: false
inference:
  mode: SLIDING_WINDOW
  sw_size:
  - 96
  - 96
  - 96
  sw_batch_size: 4
  sw_overlap: 0.25
  model_input_size:
  - 96
  - 96
  - 96
loss:
  loss_type: CE_DICE
  lambda_dice: 1.0
  lambda_ce: 0.3
  cedice_batch: false
  alpha: 0.5
scheduler:
  scheheduler_type: LINEARLR
  linear_start_factor: 1.0
  linear_end_factor: 0.01
  linear_total_iters: null
  resume_schedule: true
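
The `loss` section above selects `CE_DICE` with `lambda_dice: 1.0` and `lambda_ce: 0.3`. One plausible reading, assumed here, is a weighted sum of a soft Dice loss and a cross-entropy loss, which is how MONAI's `DiceCELoss` combines its two terms. The sketch below illustrates that weighting for a binary foreground/background case in plain Python. It is not the repository's implementation, and it does not model the `alpha` or `cedice_batch` settings.

```python
import math


def dice_loss(probs, targets, eps=1e-5):
    """Soft Dice loss for foreground probabilities and binary targets."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    denominator = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


def cross_entropy_loss(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over all voxels."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(probs)


def ce_dice_loss(probs, targets, lambda_dice=1.0, lambda_ce=0.3):
    """Weighted sum of the two terms (assumed form of CE_DICE)."""
    return (lambda_dice * dice_loss(probs, targets)
            + lambda_ce * cross_entropy_loss(probs, targets))


# Four voxels: two vessel voxels followed by two background voxels.
targets = [1, 1, 0, 0]
print(ce_dice_loss([0.9, 0.8, 0.2, 0.1], targets))  # confident, mostly right
print(ce_dice_loss([0.6, 0.5, 0.5, 0.4], targets))  # hesitant prediction
```

The Dice term is insensitive to the large number of background voxels, which matters for thin structures such as coronary arteries, while the cross-entropy term provides a per-voxel gradient. The weights set the balance between the two.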

Configuration

ML4MIP uses Hydra to manage configurations. Default values for all parameters are defined in conf/config.yaml and can be edited there. Individual settings can also be overridden directly on the command line:

train batch_size=16 lr=0.01 num_epochs=100
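
The `key=value` arguments above replace the matching defaults, and Hydra also accepts dotted keys for nested settings. The sketch below imitates that behaviour in plain Python to show the semantics. It is not Hydra's implementation: `apply_overrides`, `parse_value`, and the trimmed `DEFAULTS` dictionary are illustrative names, and the value parsing is deliberately simplified.

```python
import copy

# A small subset of the defaults printed by `train --help`.
DEFAULTS = {
    "batch_size": 16,
    "lr": 0.01,
    "num_epochs": 100,
    "loss": {"lambda_dice": 1.0, "lambda_ce": 0.3},
    "inference": {"sw_overlap": 0.25},
}


def parse_value(raw):
    """Convert an override string to None, bool, int, float, or str."""
    if raw == "null":
        return None
    if raw in ("true", "false"):
        return raw == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Return a copy of config with 'dotted.key=value' overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node[name]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = parse_value(raw_value)
    return result


updated = apply_overrides(DEFAULTS, ["batch_size=8", "loss.lambda_ce=0.5"])
print(updated["batch_size"], updated["loss"]["lambda_ce"])
```

The equivalent command line would be `train batch_size=8 loss.lambda_ce=0.5`. The sketch rejects keys that are not in the defaults, which mirrors Hydra's default behaviour of refusing to override a key that is missing from the configuration.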

Available Commands

train
validate
inference
preprocessing
extract_graph
postprocessing

For dataset settings and advanced configurations, refer to conf/config.yaml.
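
As an example of what such settings control, the default `inference` section uses `mode: SLIDING_WINDOW` with a 96 × 96 × 96 window, `sw_overlap: 0.25`, and `sw_batch_size: 4`. The sketch below shows one common way to place overlapping windows along a single axis: the stride is the window length reduced by the overlap, and the last window is clamped to the volume border. This placement rule is an assumption for illustration and `window_starts` is a hypothetical helper. The 600 × 600 × 280 volume is borrowed from `target_spatial_size` only to have concrete numbers.

```python
import math


def window_starts(image_len, window_len, overlap):
    """Start indices of overlapping windows that cover one axis."""
    if image_len <= window_len:
        return [0]  # a single (padded) window is enough
    stride = max(1, int(window_len * (1.0 - overlap)))
    count = math.ceil((image_len - window_len) / stride) + 1
    # Clamp the last start so the window never runs past the volume border.
    return [min(i * stride, image_len - window_len) for i in range(count)]


volume = (600, 600, 280)   # illustrative size, from target_spatial_size
window = (96, 96, 96)      # sw_size
overlap = 0.25             # sw_overlap
sw_batch_size = 4

per_axis = [window_starts(n, w, overlap) for n, w in zip(volume, window)]
total_windows = math.prod(len(starts) for starts in per_axis)
forward_passes = math.ceil(total_windows / sw_batch_size)

print([len(starts) for starts in per_axis], total_windows, forward_passes)
```

Under these assumptions the stride is 72 voxels, giving 8 × 8 × 4 = 256 windows and 64 batched forward passes per volume. Raising `sw_overlap` smooths seams between neighbouring windows at the cost of more forward passes.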

Logging and Experiment Tracking

ML4MIP integrates with MLFlow for tracking experiments. Logs, model checkpoints, and evaluation metrics are stored under ml_flow_uri, which can be set in the configuration.

Repository Structure

src/          # Main codebase for pipeline components and workflows  
experiments/  # Jupytext notebooks for various experimental analyses  

U-Net architecture

flowchart LR
    %% Skip connections
    C -- skip1 --> I
    D -- skip2 --> H
    E -- skip3 --> G

    A[**Input**<br/>*1 channel*] --> B[firstBlock<br/>*Conv3d → BatchNorm3d → ReLU*]
    B --> C[**en1**<br/>*Conv3d → BatchNorm3d → ReLU*]
    C --> D[**en2**<br/>*MaxPool3d → #40;Conv3d → BatchNorm3d → ReLU#41; ×2*]
    D --> E[**en3**<br/>*MaxPool3d → #40;Conv3d → BatchNorm3d → ReLU#41; ×2*]
    E --> F[**valley**<br/>*MaxPool3d → #40;Conv3d → BatchNorm3d → ReLU#41; ×2 → ConvTranspose3d*]
    F --> G[**dec1**<br/>*#40;Conv3d → BatchNorm3d → ReLU#41; ×2 → ConvTranspose3d*]
    G --> H[**dec2**<br/>*#40;Conv3d → BatchNorm3d → ReLU#41; ×2 → ConvTranspose3d*]
    H --> I[**dec3**<br/>*#40;Conv3d → BatchNorm3d → ReLU#41; ×2 → Conv3d*]
    I --> J[**Output**<br/>*Raw score*]
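
The diagram implies a constraint on the input size: each skip connection joins an encoder output with a decoder input, so the two must have the same spatial size. The sketch below traces one spatial dimension through the blocks. It assumes that every MaxPool3d halves the size, every ConvTranspose3d doubles it, and the Conv3d layers preserve it through padding. The diagram does not state these details, so treat them as assumptions. `trace_unet` and `skips_align` are illustrative helpers.

```python
def trace_unet(size=96):
    """Output size of each block along one spatial dimension."""
    out = {"firstBlock": size, "en1": size}
    out["en2"] = out["en1"] // 2      # MaxPool3d, then convolutions
    out["en3"] = out["en2"] // 2      # MaxPool3d, then convolutions
    bottleneck = out["en3"] // 2      # MaxPool3d inside the valley
    out["valley"] = bottleneck * 2    # ConvTranspose3d at the end of valley
    out["dec1"] = out["valley"] * 2   # ConvTranspose3d at the end of dec1
    out["dec2"] = out["dec1"] * 2     # ConvTranspose3d at the end of dec2
    out["dec3"] = out["dec2"]         # final Conv3d keeps the size
    return out


# (encoder block, block whose output feeds the decoder receiving the skip)
SKIPS = [("en1", "dec2"), ("en2", "dec1"), ("en3", "valley")]


def skips_align(sizes):
    """True if every skip connection joins tensors of equal spatial size."""
    return all(sizes[enc] == sizes[dec_in] for enc, dec_in in SKIPS)


for size in (96, 100):
    sizes = trace_unet(size)
    print(size, sizes, "skips align:", skips_align(sizes))
```

With three pooling steps, the sizes only line up when the input is divisible by 8, which the 96-voxel patches in the default configuration satisfy. A size such as 100 would leave the deepest skip connection off by one voxel.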