This repository contains the source code for the multi-task and baseline variational networks from our manuscript, "Multi-Task Accelerated MR Reconstruction Schemes for Jointly Training Multiple Contrasts", by Victoria Liu, Kanghyun Ryu, Cagan Alkan, John Pauly, and Shreyas Vasanawala. The manuscript has been accepted to the 2021 NeurIPS Deep Inverse Workshop; check out a preprint here while we work on the workshop-ready version!
Model-based accelerated MRI reconstruction methods leverage large datasets to reconstruct diagnostic-quality images from undersampled k-space. These networks require matching training and test-time distributions to achieve high-quality reconstructions. However, there is inherent variability in MR datasets, including different contrasts, orientations, anatomies, and institution-specific protocols. The current paradigm is to train separate models for each dataset. However, this is a demanding process and cannot exploit information that may be shared amongst datasets. To address this issue, we propose multi-task learning (MTL) schemes that can jointly reconstruct multiple datasets. We test multiple MTL architectures and weighted loss functions against single-task learning (STL) baselines. Our quantitative and qualitative results suggest that MTL can outperform STL across a range of dataset ratios for two knee contrasts.
Figure 1: Multi-task learning network architectures built on an unrolled variational network for (a) fully split and (b) shared-encoder-split-decoder blocks.
- Clone this repo:
git clone https://github.com/liuvictoria/multitasklearning.git
- Create a virtual environment using conda-env.txt. Install packages in requirements.txt.
- Download MRI data from mridata.org. We use the PDw and PDw-FS coronal volumes.
In the multiTaskLearning folder, you'll be able to run STL, transfer learning, MTL, and evaluation experiments from the command line. There are required and optional user-defined values that are passed in as flags. Here are some examples:
python3 stl.py --datasets div_coronal_pd_fs div_coronal_pd --epochs 120 --experimentname STL_joint --savefreq 10 --device cuda:2 --scarcities 1 2 3
python3 stl.py --datasets div_coronal_pd --epochs 120 --experimentname STL_nojoint --savefreq 10 --device cuda:2 --scarcities 1 2 3 --mixeddata 0
python3 mtl.py --datasets div_coronal_pd_fs div_coronal_pd --weighting uncert --blockstructures trueshare trueshare mhushare mhushare mhushare mhushare mhushare mhushare mhushare mhushare mhushare mhushare --epochs 120 --experimentname MTL_uncert_mhushare --device cuda:2 --scarcities 1 2 3 --savefreq 10
python3 mtl.py --datasets div_coronal_pd_fs div_coronal_pd --weighting naive --blockstructures trueshare trueshare split split split split split split split split split split --epochs 120 --experimentname MTL_naive_split --device cuda:1 --scarcities 1 2 3 --savefreq 10
python3 stl.py --datasets div_coronal_pd_fs --epochs 120 --experimentname STL_transfer5e-4 --savefreq 10 --device cuda:1 --scarcities 1 2 3 --lr 0.0005 --weightsdir /mnt/dense/vliu/summer_runs_models/models/STL_baselines/STL_nojoint_varnet_div_coronal_pd/N=481_l1.pt --mixeddata 0
python3 evaluate.py --datasets div_coronal_pd_fs div_coronal_pd --stratified 0 --network varnet --shareetas 1 --mixeddata 1 --scarcemax 497 --device cuda:1 --plotdir neurips/MTL_naive_split_varnetIIVVVVVVVVVV_div_coronal_pd_fs_div_coronal_pd --experimentnames MTL_naive_split --blockstructures IIVVVVVVVVVV --showbaselines 1
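The commands above pass several flags as space-separated lists (`--datasets`, `--scarcities`, `--blockstructures`). As a minimal sketch of how such list-valued flags can be parsed, the following uses `argparse` with `nargs="+"`. The parser here is hypothetical and is not the repository's own; the defaults other than `cuda:2` are assumptions made for illustration. It also shows how the 12-entry `--blockstructures` list from the MTL example can be built programmatically rather than typed out.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring a few of the flags used above."""
    parser = argparse.ArgumentParser(description="Sketch of list-valued flags")
    # nargs="+" collects one or more space-separated values into a list.
    parser.add_argument("--datasets", nargs="+", required=True)
    parser.add_argument("--blockstructures", nargs="+", default=[])
    parser.add_argument("--scarcities", nargs="+", type=int, default=[0])
    parser.add_argument("--epochs", type=int, default=100)  # assumed default
    parser.add_argument("--device", default="cuda:2")
    return parser


# Two shared blocks followed by ten mhushare blocks, as in the MTL example.
block_list = ["trueshare"] * 2 + ["mhushare"] * 10

argv = [
    "--datasets", "div_coronal_pd_fs", "div_coronal_pd",
    "--blockstructures", *block_list,
    "--epochs", "120",
    "--scarcities", "1", "2", "3",
    "--device", "cuda:2",
]
args = build_parser().parse_args(argv)
print(args.datasets, len(args.blockstructures), args.scarcities)
```

Each list flag ends where the next `--flag` begins, so the order of flags on the command line does not matter.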
In addition, there is a script to delete aborted runs (usually due to GPU issues) and a script to visualize the undersampled images before any training:
python3 deleteDuplicateRuns.py --runnames MTL_uncert_gradacc_varnetIIYYYYYYYYYY_div_coronal_pd_fs_div_coronal_pd --rundir ../../../mnt/dense/vliu/runs/one_eta_pdfs_scarce
python3 getusamp.py --datasets div_coronal_pd_fs div_coronal_pd --device cuda:2 --plotdir usamp
The ESPIRiT folder contains scripts to estimate sensitivity maps.
In the following tables, we list some flags that may be of interest to you. The set of required flags differs amongst scripts, and there may be default values for some of the flags. Please consult the extensive documentation within the code for this additional information.
Flag Name | Usage | Relevant in |
---|---|---|
epochs | number of epochs to run | mtl, stl |
lr | learning rate | mtl, stl |
gradaccumulation | how many iterations per gradient accumulation | mtl, stl |
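To illustrate what the `gradaccumulation` flag controls, here is a minimal, framework-free sketch of gradient accumulation: gradients from several iterations are summed, and the weights are updated only once per accumulation window. The one-parameter model, the function name `sgd_with_accumulation`, and the averaging of gradients over the window are assumptions for illustration; the repository's training loop may differ.

```python
def sgd_with_accumulation(samples, w, lr, accum_steps):
    """Fit y = w * x with squared error, stepping once per `accum_steps` samples."""
    grad_sum = 0.0
    for i, (x, y) in enumerate(samples, start=1):
        # Gradient of 0.5 * (w * x - y)^2 with respect to w,
        # divided by accum_steps so the accumulated value is an average.
        grad_sum += (w * x - y) * x / accum_steps
        if i % accum_steps == 0:
            w -= lr * grad_sum  # one weight update per accumulation window
            grad_sum = 0.0      # reset, like zeroing gradients after a step
    return w


samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w_accum = sgd_with_accumulation(samples, w=0.0, lr=0.1, accum_steps=4)
print(w_accum)
```

With `accum_steps=4`, the four samples behave like a single batch of size four, which is the usual reason to accumulate when a larger batch does not fit in GPU memory.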
Flag Name | Usage | Relevant in |
---|---|---|
blockstructures | explicit list of what each unrolled block's convolutional network will be | mtl, stl, evaluate |
shareetas | whether to share the data consistency term eta | mtl, stl, evaluate |
weighting | loss weighting scheme: naive, uncert, or dwa | mtl, stl, evaluate |
device | device to run on (default: cuda:2) | mtl, stl, evaluate |
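The table above does not spell out the formulas behind the `weighting` options. As a rough sketch, assuming `naive` means an unweighted sum, `uncert` refers to uncertainty-based weighting with one log-variance per task, and `dwa` refers to Dynamic Weight Average, the three schemes can be written as follows. All function names and the temperature value are hypothetical, and the repository's implementation may differ in detail.

```python
import math


def naive_total(losses):
    """Unweighted sum of per-task losses."""
    return sum(losses)


def uncertainty_total(losses, log_vars):
    """Uncertainty weighting: scale each loss by exp(-s) and add s as a penalty.

    In training, each s (a log-variance) would be a learnable parameter;
    here they are plain floats for illustration.
    """
    return sum(math.exp(-s) * loss + s for loss, s in zip(losses, log_vars))


def dwa_weights(prev_losses, prev_prev_losses, temperature=2.0):
    """Dynamic Weight Average: tasks whose loss falls more slowly get more weight."""
    ratios = [a / b for a, b in zip(prev_losses, prev_prev_losses)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    # Weights are normalized so that they sum to the number of tasks.
    return [len(ratios) * e / total for e in exps]


losses = [0.5, 0.25]
print(naive_total(losses))
print(uncertainty_total(losses, [0.0, 0.0]))
print(dwa_weights([0.9, 0.5], [1.0, 1.0]))
```

With all log-variances at zero, the uncertainty-weighted total reduces to the naive sum, which makes a convenient sanity check when comparing schemes.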
Flag Name | Usage | Relevant in |
---|---|---|
datasets | names of relevant tasks | mtl, stl, evaluate |
datadir | data root directory containing the datasets | mtl, stl, evaluate |
scarcities | for each value N, the number of samples in the second task is scaled by a factor of 1/2^N | mtl, stl |
accelerations | list of k-space undersampling factors for training | mtl, stl, evaluate |
centerfracs | list of k-space center fractions sampled for training | mtl, stl, evaluate |
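As a small worked example of the `scarcities` flag, the sketch below computes how many samples of the second task would remain for each value N, assuming the count is scaled by 1/2^N. The helper name, the full-dataset size of 512, and the use of floor division for rounding are all assumptions for illustration.

```python
def scarce_sample_count(n_full, scarcity):
    """Number of samples left after scaling by 1 / 2**scarcity (floor division)."""
    return n_full // 2 ** scarcity


n_full = 512  # hypothetical size of the second task's full training set
counts = {n: scarce_sample_count(n_full, n) for n in (1, 2, 3)}
print(counts)
```

Passing `--scarcities 1 2 3`, as in the commands above, would under this assumption correspond to three settings with one half, one quarter, and one eighth of the second task's data.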
Flag Name | Usage | Relevant in |
---|---|---|
stratified | whether to use a stratified dataloader | mtl, stl |
numworkers | number of workers for PyTorch dataloader | mtl, stl, evaluate |
Flag Name | Usage | Relevant in |
---|---|---|
experimentname | experiment name for training only | mtl, stl |
verbose | verbosity on terminal | mtl, stl |
tensorboard | whether to create tensorboard | mtl, stl, evaluate |
savefreq | how frequently to save recon image | mtl, stl, evaluate |
Flag Name | Usage |
---|---|
colors | Bokeh color palette; see https://docs.bokeh.org/en/latest/docs/reference/palettes.html#bokeh-palette |
createplots | whether to create plots of metrics for different ratios of MRI |
showbaselines | whether to show (pre-calculated) STL baselines in plot |
showbest | whether to show the best MTL run |
You can explore how MR images improve throughout training epochs. The top left is the reconstructed image, and the top right is the ground truth. The bottom row contains error maps.
The Bokeh plots are interactive. You can zoom in, see exact values by hovering over a point, and declutter the plot by clicking the legend.
At the inference step, you can also scroll through entire MRI volumes to see how reconstruction affects different slices.
We thank the Stanford MRSRL group for their insightful comments during group meetings. We also thank Hammernik et al. for making the MRI data publicly available. We are indebted to the volunteers who consented to have their MRIs available for research.
If you found our repository or paper helpful, please cite:
{to be updated on October 15 when the preprint is released}
If you have any questions or comments, please contact me at vliu@caltech.edu.