Merge pull request #38 from MELDProject/dev_docker
Updates for JAMA revisions
kwagstyl authored Aug 29, 2024
2 parents 6256c95 + 3953402 commit f4ecbe6
Showing 45 changed files with 4,378 additions and 7,855 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.0.0
current_version = 2.1.0
commit = True
tag = True
files = setup.py meld_graph/__init__.py
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@

Graph based FCD lesion segmentation for the [MELD project](https://meldproject.github.io/).

This package offers a friendly user pipeline to segment FCD-lesions from MRI scans.
This package is a pipeline to segment FCD-lesions from MRI scans.

![overview](https://raw.githubusercontent.com//MELDProject/meld_graph/dev_docker/docs/images/Fig1_pipeline.jpg)

1 change: 1 addition & 0 deletions data/feature_means_no_combat_mrineghisto.json
@@ -0,0 +1 @@
{".combat.on_lh.pial.K_filtered.sm20.mgh": {"mean": 0.5038626194000244, "std": 0.3130382299423218}, ".combat.on_lh.thickness.sm3.mgh": {"mean": 2.6667654514312744, "std": 0.6937940716743469}, ".combat.on_lh.thickness_regression.sm3.mgh": {"mean": 2.6667654514312744, "std": 0.6705865263938904}, ".combat.on_lh.w-g.pct.sm3.mgh": {"mean": 24.170522689819336, "std": 9.622745513916016}, ".combat.on_lh.sulc.sm3.mgh": {"mean": 0.05367401987314224, "std": 0.5665598511695862}, ".combat.on_lh.curv.sm3.mgh": {"mean": -0.02662261761724949, "std": 0.14215724170207977}, ".combat.on_lh.gm_FLAIR_0.75.sm3.mgh": {"mean": 136.82968139648438, "std": 17.612781524658203}, ".combat.on_lh.gm_FLAIR_0.5.sm3.mgh": {"mean": 140.94927978515625, "std": 16.043102264404297}, ".combat.on_lh.gm_FLAIR_0.25.sm3.mgh": {"mean": 137.48272705078125, "std": 14.159014701843262}, ".combat.on_lh.gm_FLAIR_0.sm3.mgh": {"mean": 128.0760040283203, "std": 10.51087760925293}, ".combat.on_lh.wm_FLAIR_0.5.sm3.mgh": {"mean": 119.6790542602539, "std": 7.96249532699585}, ".combat.on_lh.wm_FLAIR_1.sm3.mgh": {"mean": 114.50215148925781, "std": 7.1296820640563965}, ".inter_z.intra_z.combat.on_lh.pial.K_filtered.sm20.mgh": {"mean": 0.020965227857232094, "std": 1.1048777103424072}, ".inter_z.intra_z.combat.on_lh.thickness_regression.sm3.mgh": {"mean": 0.007747821509838104, "std": 1.066899299621582}, ".inter_z.intra_z.combat.on_lh.w-g.pct.sm3.mgh": {"mean": 0.007210636977106333, "std": 1.0602881908416748}, ".inter_z.intra_z.combat.on_lh.sulc.sm3.mgh": {"mean": 0.006064554676413536, "std": 1.0923455953598022}, ".inter_z.intra_z.combat.on_lh.curv.sm3.mgh": {"mean": -0.002045442583039403, "std": 1.0671255588531494}, ".inter_z.intra_z.combat.on_lh.gm_FLAIR_0.75.sm3.mgh": {"mean": -0.04725680127739906, "std": 1.263229489326477}, ".inter_z.intra_z.combat.on_lh.gm_FLAIR_0.5.sm3.mgh": {"mean": -0.02844412811100483, "std": 1.2561675310134888}, ".inter_z.intra_z.combat.on_lh.gm_FLAIR_0.25.sm3.mgh": {"mean": -0.006390768568962812, "std": 1.2389447689056396}, ".inter_z.intra_z.combat.on_lh.gm_FLAIR_0.sm3.mgh": {"mean": 0.010225020349025726, "std": 1.199887990951538}, ".inter_z.intra_z.combat.on_lh.wm_FLAIR_0.5.sm3.mgh": {"mean": 0.020151548087596893, "std": 1.1929118633270264}, ".inter_z.intra_z.combat.on_lh.wm_FLAIR_1.sm3.mgh": {"mean": 0.03191414847970009, "std": 1.1793659925460815}, ".inter_z.asym.intra_z.combat.on_lh.pial.K_filtered.sm20.mgh": {"mean": -4.3655745685100555e-10, "std": 1.1215183734893799}, ".inter_z.asym.intra_z.combat.on_lh.thickness_regression.sm3.mgh": {"mean": -1.4551915228366852e-10, "std": 1.0536749362945557}, ".inter_z.asym.intra_z.combat.on_lh.w-g.pct.sm3.mgh": {"mean": -8.149072527885437e-10, "std": 1.0781065225601196}, ".inter_z.asym.intra_z.combat.on_lh.sulc.sm3.mgh": {"mean": 6.548361852765083e-11, "std": 1.1036579608917236}, ".inter_z.asym.intra_z.combat.on_lh.curv.sm3.mgh": {"mean": 4.3655745685100555e-11, "std": 1.0685703754425049}, ".inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.75.sm3.mgh": {"mean": 2.3283064365386963e-10, "std": 1.2068378925323486}, ".inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.5.sm3.mgh": {"mean": -4.656612873077393e-10, "std": 1.2048578262329102}, ".inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.25.sm3.mgh": {"mean": -2.3283064365386963e-10, "std": 1.1929599046707153}, ".inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.sm3.mgh": {"mean": 3.4924596548080444e-10, "std": 1.1686742305755615}, ".inter_z.asym.intra_z.combat.on_lh.wm_FLAIR_0.5.sm3.mgh": {"mean": 0.0, "std": 1.170699119567871}, 
".inter_z.asym.intra_z.combat.on_lh.wm_FLAIR_1.sm3.mgh": {"mean": 1.5279510989785194e-10, "std": 1.1748234033584595}}
Binary file removed data/icospheres/ico7.spirals.npy
Binary file not shown.
8 changes: 4 additions & 4 deletions docs/figure_notebooks.md
@@ -2,7 +2,8 @@

**Figures:**\
- Figure 1. Overview\
- [Figure 2. Examples of predictions](/notebooks/plot_examples_prediction.ipynb)\
- [Figure 2A. Comparison of predictions & cluster numbers](/notebooks/plot_examples_prediction.ipynb)\
- [Figure 2B. Comparison cluster numbers](/notebooks/compare_results_MLP.ipynb)\
- [Figure 3. Examples of interpretable patient reports](/notebooks/plot_examples_reports.ipynb)

**Tables:**\
@@ -11,12 +12,11 @@

**eTables:**\
- [eTable 1. Demographics](/notebooks/demographics_table.ipynb)\
- [eTable 2. Reduction in false positive clusters](/notebooks/plot_clusters_number.ipynb)\
- [eTable 2. Reduction in false positive clusters](/notebooks/compare_results_MLP.ipynb)\
- [eTable 3. MELD Graph model with harmonized vs non-harmonized MRI features](/notebooks/compare_results_MLP.ipynb)

**eFigures:**\
- eFigure 1. MELD Graph model architecture\
- eFigure 2. Pipeline for running a new patient’s MRI scan through MELD Graph\
- [eFigure 3. Stability of NeuroCombat](/notebooks/combat_subsampling.ipynb)\
- [eFigure 4. Characterisation of detected FCD lesions.](/notebooks/analysis_saliencies.ipynb)\
- [eFigure 5. Calibration of confidence scores](/notebooks/plot_confidence_calibration.ipynb)
- [eFigure 4. Characterisation of predictions](/notebooks/analysis_predictions.ipynb)\
2 changes: 1 addition & 1 deletion meld_graph/__init__.py
@@ -1,3 +1,3 @@
__author__ = __maintainer__ = "MELD development team"
__email__ = "meld.study@gmail.com"
__version__ = "1.0.0"
__version__ = "2.1.0"
38 changes: 22 additions & 16 deletions meld_graph/confidence.py
@@ -102,9 +102,9 @@ def calibration_plot(results_dict, n_bins=10, confidence='confidence_lesion'):
ax.plot([0,1],[0,1], '--')
ax.set_ylim(0,1)
ax.set_xlim(0,1)
ax.set_xlabel(confidence)
ax.set_ylabel('frequency of TPs')
ax.set_title('Per vertex confidence (ECE: {:.2f})'.format(ece))
ax.set_xlabel(confidence, fontsize=15)
ax.set_ylabel('frequency of TPs', fontsize=15)
ax.set_title('Per vertex confidence (ECE: {:.2f})'.format(ece), fontsize=12)
return fig

def calculate_per_cluster_confidence(results_dict, aggregation_fn='median', mask_by_saliency=False, eva=None, prediction_suffix=""):
@@ -147,7 +147,7 @@ def calculate_per_cluster_confidence(results_dict, aggregation_fn='median', mask
return results_df


def cluster_calibration_plot(confidence, label, n_bins=10):
def cluster_calibration_plot(confidence, label, n_bins=10,ax=None):
"""
calculate ECE as described in literature
calclulate calibration plot as calculated by sklearn.calibration
@@ -180,24 +180,30 @@ def cluster_calibration_plot(confidence, label, n_bins=10):
ece = np.nansum(np.abs(np.array(freq) - np.array(conf))*np.array(n))
#print('ECE: ', ece)
mask=np.array(mask)
fig, ax = plt.subplots(1,1, figsize=(5,5))
ax.plot((bins[:-1] + (bins[1:]-bins[:-1])/2)[mask], np.array(freq)[mask], 'o-')
if ax is None:
fig, ax = plt.subplots(1,1, figsize=(5,5))
#ax.plot((bins[:-1] + (bins[1:]-bins[:-1])/2)[mask], np.array(freq)[mask], 'o-')
ax.bar(bins[:-1] + (bins[1:]-bins[:-1])/2, n, width=0.05, color='black', alpha=0.5)
#changes
#ax.scatter((bins[:-1] + (bins[1:]-bins[:-1])/2)[mask], np.array(freq)[mask], 'o')
sns.regplot(x=bins[:-1] + (bins[1:]-bins[:-1])/2, y=freq, ax=ax, scatter_kws={'s': 10})
# plot line for perfect calibration
ax.plot([0,1],[0,1], '--')
ax.set_ylim(0,1)
ax.set_xlim(0,1)
ax.set_xlabel('confidence')
ax.set_ylabel('frequency of TPs')
ax.set_xlabel('Confidence', fontsize=15)
ax.set_ylabel('Frequency of TPs', fontsize=15)
ax.set_title('Per cluster confidence (ECE: {:.2f})'.format(ece))
return fig
return ax

def confidence_label_distplot(per_cluster_confidence, per_cluster_label):
def confidence_label_distplot(per_cluster_confidence, per_cluster_label,ax=None):
# quick check to see if this roughly makes sense - FP clusters should have low confidence
fig, ax = plt.subplots(1,1)
sns.kdeplot(x=per_cluster_confidence[per_cluster_label==0], bw_adjust=0.8, ax=ax, label='FP')
sns.kdeplot(x=per_cluster_confidence[per_cluster_label==1], bw_adjust=0.8, ax=ax, label='TP')
plt.legend()
ax.set_xlabel('confidence')
if ax is None:
fig, ax = plt.subplots(1,1)
sns.kdeplot(x=per_cluster_confidence[per_cluster_label==0], bw_adjust=0.5, ax=ax, label='False Positives')
sns.kdeplot(x=per_cluster_confidence[per_cluster_label==1], bw_adjust=0.5, ax=ax, label='True Positives')
ax.legend()
ax.set_xlabel('Confidence', fontsize=15)
ax.set_ylabel('Density', fontsize=15)
ax.set_xlim([0,1])
return fig
return ax
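Both plotting helpers now take an optional ax and return it instead of creating their own figure, so a calling notebook can compose them onto one canvas. A minimal usage sketch with synthetic confidences and labels; the two-panel layout and the synthetic data are assumptions, only the function signatures come from the diff above:

```python
import numpy as np
import matplotlib.pyplot as plt
from meld_graph.confidence import cluster_calibration_plot, confidence_label_distplot

rng = np.random.default_rng(0)
confidence = rng.uniform(0, 1, 200)                        # synthetic per-cluster confidences
label = (rng.uniform(0, 1, 200) < confidence).astype(int)  # synthetic TP (1) / FP (0) labels

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))
cluster_calibration_plot(confidence, label, n_bins=10, ax=ax0)
confidence_label_distplot(confidence, label, ax=ax1)
fig.tight_layout()
plt.show()
```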
1 change: 1 addition & 0 deletions meld_graph/dataset.py
@@ -150,6 +150,7 @@ def __init__(
for s_i, subj_id in enumerate(self.subject_ids):
# load in (control) data
# features are appended to list in order: left, right
print(subj_id)
subject_data_list = self.prep.get_data_preprocessed(
subject=subj_id,
features=params["features"],
77 changes: 70 additions & 7 deletions meld_graph/evaluation.py
@@ -43,7 +43,7 @@ def __init__(
checkpoint_path=None,
make_images=False,
thresh_and_clust=True,
threshold="two_threshold",
threshold="slope_threshold",
min_area_threshold=100,
dataset=None,
cohort=None,
@@ -143,6 +143,22 @@ def __init__(
else:
print(f"Could not find an optimised threshold file at {threshold_file}. You need to run script optimise_sigmoid_trainval.py")
return
elif threshold == "multi_threshold":
self.threshold_mode = 'multi_threshold'
self.threshold = list([0.01, 0.1, 0.2, 0.3, 0.4, 0.5])
self.log.info("Evaluation {}, min area threshold={}, prediction thresholds={})".format(
self.mode,self.min_area_threshold, self.threshold))
elif threshold == "max_threshold":
self.threshold_mode = 'max_threshold'
self.threshold = {'min_thresh':0.01, 'max_thresh':0.5, 'margin': 0.25}
self.log.info("Evaluation {}, min area threshold={}, prediction thresholds={})".format(
self.mode,self.min_area_threshold, self.threshold))
elif threshold == "slope_threshold":
self.threshold_mode = 'slope_threshold'
self.threshold = {'min_thresh':0.01, 'max_thresh':0.5, 'slope': 0.2}
self.log.info("Evaluation {}, min area threshold={}, prediction threshold={})".format(
self.mode,self.min_area_threshold, self.threshold))

elif isinstance(threshold, float):
self.threshold_mode = 'threshold'
self.threshold = threshold
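For reference, the three threshold modes added above differ only in the fixed parameters they set. Summarised as an illustrative dict; the values are copied from the diff, the dict itself is not part of the code:

```python
# Illustrative summary of the threshold configurations set in __init__ above.
THRESHOLD_MODES = {
    # fixed candidate thresholds, tried from highest to lowest until one yields a cluster
    "multi_threshold": [0.01, 0.1, 0.2, 0.3, 0.4, 0.5],
    # use max_thresh if it yields a cluster, otherwise back off towards min_thresh
    "max_threshold": {"min_thresh": 0.01, "max_thresh": 0.5, "margin": 0.25},
    # new default: scale the subject's maximum prediction by a slope, floored at min_thresh
    "slope_threshold": {"min_thresh": 0.01, "max_thresh": 0.5, "slope": 0.2},
}
```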
@@ -382,7 +398,7 @@ def roc_dictionary(self):
"specificity": np.zeros(len(self.thresholds)),
}
return self._roc_dictionary

def threshold_and_cluster(self, data_dictionary=None, save_prediction=True, save_prediction_suffix=""):
# helper fn getting the clustered and thresholded data for a given threshold
def get_cluster_thresholded(predictions, threshold_subj):
@@ -410,18 +426,67 @@ def get_cluster_thresholded(predictions, threshold_subj):
threshold_subj = self.threshold

predictions = self.experiment.cohort.split_hemispheres(data["result"])
if self.threshold_mode == 'two_threshold':
if self.threshold_mode == 'multi_threshold':
#list of multiple thresholds fixed
# initialise best threshold to be the smallest one
thresholds = np.sort(threshold_subj)[::-1]
best_threshold = thresholds[-1]
# loop over descending thresholds and keep the highest threshold that give a prediction
for threshold in thresholds:
if data["result"].max() > threshold:
cluster_thresholded = get_cluster_thresholded(predictions, threshold)
if cluster_thresholded.sum() > 0 :
best_threshold = threshold
break
cluster_thresholded = get_cluster_thresholded(predictions, best_threshold)
print(f"threshold_subj = {best_threshold}")
data["threshold"] = best_threshold
data["cluster_thresholded"] = cluster_thresholded
if self.threshold_mode == 'max_threshold':
# Threshold = max_thresh if max(predictions) > max_thresh
# else Threshold = max(max(predictions)-margin,min_thresh)
best_threshold = 0
if data["result"].max() > threshold_subj['max_thresh']:
threshold = threshold_subj['max_thresh']
cluster_thresholded = get_cluster_thresholded(predictions, threshold)
if cluster_thresholded.sum() > 0 :
best_threshold = threshold
if best_threshold==0:
threshold = max((min(threshold_subj['max_thresh']-threshold_subj['margin'], data["result"].max() - threshold_subj['margin'])), threshold_subj['min_thresh'])
threshold = min(threshold_subj['max_thresh']/2, max(data["result"].max()/2, 0.01))
cluster_thresholded = get_cluster_thresholded(predictions, threshold)
if cluster_thresholded.sum() > 0 :
best_threshold = threshold
if best_threshold==0:
best_threshold = 0.01
cluster_thresholded = get_cluster_thresholded(predictions, best_threshold)
print(f"threshold_subj = {best_threshold}")
data["threshold"] = best_threshold
data["cluster_thresholded"] = cluster_thresholded
if self.threshold_mode == 'slope_threshold':
m = data["result"].max()
if (data["result"]>=threshold_subj['max_thresh']).sum()>100:
best_threshold = threshold_subj['max_thresh']
else:
best_threshold = np.max([data["result"].max()*threshold_subj['slope'],threshold_subj['min_thresh']])
cluster_thresholded = get_cluster_thresholded(predictions, best_threshold)
print(f"threshold_subj = {best_threshold}")
data["threshold"] = best_threshold
data["cluster_thresholded"] = cluster_thresholded
elif self.threshold_mode == 'two_threshold':
#thresholds are optimised from the trainval
print(f'using thresholds {threshold_subj}')
pred_low_confidence = get_cluster_thresholded(predictions, threshold_subj[0])
pred_high_confidence = get_cluster_thresholded(predictions, threshold_subj[1])
data["cluster_thresholded_low_conf"] = pred_low_confidence
data["cluster_thresholded_high_conf"] = pred_high_confidence
#data["cluster_thresholded"] = pred_high_confidence if data["result"].max() > threshold_subj[1] else pred_low_confidence
data["cluster_thresholded"] = pred_high_confidence if pred_high_confidence.sum()>0 else pred_low_confidence
if data["result"].max() > threshold_subj[1]:
print(f"threshold_subj = {threshold_subj[1]}")
data["threshold"] = threshold_subj[1]
else:
print(f"threshold_subj = {threshold_subj[0]}")
data["threshold"] = threshold_subj[0]
else:
data["cluster_thresholded"] = get_cluster_thresholded(predictions, threshold_subj)
if save_prediction:
@@ -596,6 +661,7 @@ def _load_data_from_file(subj_id):
mask_salient[vertices_salient]= True
#rearange saliencies and mask salient in whole brain - add empty hemi
empty_hemi = np.zeros(cur_saliency.shape)

if hemi=='left':
saliency_vert[subj_id][cl] = np.hstack([cur_saliency[self.experiment.cohort.cortex_mask,:].T,empty_hemi[self.experiment.cohort.cortex_mask,:].T]).T
mask_salient_vert[subj_id][cl] = np.hstack([mask_salient[self.experiment.cohort.cortex_mask],empty_hemi[self.experiment.cohort.cortex_mask, 0]])
@@ -634,9 +700,6 @@ def _load_data_from_file(subj_id):
def stat_subjects(self, suffix="", fold=None):
"""calculate stats for each subjects"""
suffix = f"{suffix}{self.dropout_suffix}"
# TODO: need to add boundaries
# boundary_label = MeldSubject(subject, self.experiment.cohort).load_boundary_zone(max_distance=20)

# calculate stats on thresholded and clustered predictions
for subject in self.data_dictionary.keys():
# use prediction clustered
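Stripped of the evaluator plumbing, the per-subject threshold chosen by the new default slope_threshold mode reduces to a small function. The standalone form below is a sketch (the function name is mine); the constants and conditions mirror the slope_threshold branch of threshold_and_cluster above:

```python
import numpy as np

def slope_threshold(result, min_thresh=0.01, max_thresh=0.5, slope=0.2):
    """Pick a per-subject threshold from vertex-wise predictions `result`.

    If more than 100 vertices already exceed max_thresh, use max_thresh;
    otherwise scale the subject's maximum prediction by `slope`,
    floored at min_thresh.
    """
    if (result >= max_thresh).sum() > 100:
        return max_thresh
    return float(np.max([result.max() * slope, min_thresh]))

# A subject whose strongest prediction is only 0.3 gets a threshold of about 0.06.
thr = slope_threshold(np.array([0.05, 0.1, 0.3]))
```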
13 changes: 12 additions & 1 deletion meld_graph/tools_pipeline.py
@@ -4,6 +4,7 @@
import json
import os
import bids.layout
import pandas as pd
from subprocess import Popen
from meld_graph.paths import MELD_DATA_PATH

@@ -97,4 +98,14 @@ def run_command(command, verbose=False):
# )
# if (proc.stdout) and (verbose):
# print(get_m("Result: {}".format(proc.stdout.decode('utf-8')), None, 'COMMAND'))
return proc
return proc

def create_demographic_file(subjects_ids, save_file, harmo_code='noHarmo'):
df = pd.DataFrame()
if isinstance(subjects_ids, str):
subjects_ids=[subjects_ids]
df['ID']=subjects_ids.astype(str)
df['Harmo code']=[str(harmo_code) for subject in subjects_ids]
df['Group']=['patient' for subject in subjects_ids]
df['Scanner']=['3T' for subject in subjects_ids]
df.to_csv(save_file)
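A minimal usage sketch for the new create_demographic_file helper, written out long-hand. The subject IDs and output path are made up; the list of IDs is wrapped in a pandas Series here so that .astype(str) is available even when a plain Python list is passed, which is an illustrative defensive variant rather than the exact call pattern used by the pipeline scripts:

```python
import pandas as pd

# Hypothetical subject IDs and output path.
subjects_ids = ["sub-0001", "sub-0002"]
save_file = "demographics_file.csv"

df = pd.DataFrame()
df["ID"] = pd.Series(subjects_ids).astype(str)   # Series provides .astype(str)
df["Harmo code"] = ["noHarmo" for _ in subjects_ids]
df["Group"] = ["patient" for _ in subjects_ids]
df["Scanner"] = ["3T" for _ in subjects_ids]
df.to_csv(save_file)
```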
1 change: 1 addition & 0 deletions meldsetup.sh
@@ -3,6 +3,7 @@
# create the meld graph environment with all the dependencies
conda env create -f environment-mac.yml
# activate the environment
eval "$(conda shell.bash hook)"
conda activate meld_graph
# install meld_graph with pip (with `-e`, the development mode, to allow changes in the code to be immediately visible in the installation)
pip install -e .