Skip to content

Commit

Permalink
Merge changes
Browse files Browse the repository at this point in the history
  • Loading branch information
GILIYAR RADHAKRISHNA Chaithya committed Jan 10, 2024
2 parents 1651cec + 55b6764 commit 6fc7d95
Show file tree
Hide file tree
Showing 45 changed files with 1,246 additions and 529 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
python -m pip install git+https://github.com/CEA-COSMIC/ModOpt.git@develop
python -m pip install git+https://github.com/CEA-COSMIC/pysap.git@develop
python -m pip install git+https://github.com/AGrigis/pysphinxdoc.git
python -m pip install coverage nose pytest pytest-cov pycodestyle twine pytest-xdist
python -m pip install coverage nose pytest pytest-cov pycodestyle pydocstyle twine pytest-xdist
python -m pip install pynfft2
python -m pip install --upgrade .
Expand All @@ -57,7 +57,8 @@ jobs:
run: |
pycodestyle mri --ignore="E121,E123,E126,E226,E24,E704,E402,E731,E722,E741,W503,W504,W605"
pycodestyle examples --ignore="E121,E123,E126,E226,E24,E704,E402,E731,E722,E741,W503,W504,W605"
pydocstyle mri --convention=numpy
pydocstyle examples --convention=numpy
- name: Run Tests
shell: bash -l {0}
run: |
Expand Down
21 changes: 21 additions & 0 deletions doc/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,24 @@ @article{man1997
year = {1997},
pages = {785-792},
}

@article{Pruessmann1999,
title={SENSE: Sensitivity encoding for fast MRI},
author={Klaas Paul Pruessmann and Markus Weiger and Markus B. Scheidegger and Peter Boesiger},
journal={Magnetic Resonance in Medicine},
year={1999},
volume={42},
}
@inproceedings{Donoho1994,
address = {Baltimore, MD, USA},
title = {Threshold selection for wavelet shrinkage of noisy data},
ISBN = {978-0-7803-2050-5},
url = {http://ieeexplore.ieee.org/document/412133/},
DOI = {10.1109/IEMBS.1994.412133},
booktitle = {Proceedings of 16th Annual International Conference of the
IEEE Engineering in Medicine and Biology Society},
publisher = {IEEE},
author = {Donoho, D.L. and Johnstone, I.M.},
year = 1994,
pages = {A24–A25}
}
225 changes: 225 additions & 0 deletions examples/cartesian_reconstruction_auto_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#!/usr/bin/env python
# coding: utf-8

#
# Neuroimaging cartesian reconstruction
# =====================================
#
# Author: Chaithya G R / Pierre-Antoine Comby
#
# In this tutorial we will reconstruct an MRI image from the sparse kspace
# measurements.
#
# Import neuroimaging data
# ------------------------
#
# We use the toy datasets available in pysap, more specifically a 2D brain slice
# and the cartesian acquisition scheme.
#

# In[1]:


import matplotlib.pyplot as plt
import numpy as np
from modopt.math.metrics import snr, ssim
from modopt.opt.linear import Identity
# Third party import
from modopt.opt.proximity import SparseThreshold
from mri.operators import FFT, WaveletN
from mri.operators.proximity.weighted import AutoWeightedSparseThreshold
from mri.operators.utils import convert_mask_to_locations
from mri.reconstructors import SingleChannelReconstructor
from pysap.data import get_sample_data

image = get_sample_data('2d-mri')
print(image.data.min(), image.data.max())
image = image.data
image /= np.max(image)
mask = get_sample_data("cartesian-mri-mask")


# Get the locations of the kspace samples
kspace_loc = convert_mask_to_locations(mask.data)
# Generate the subsampled kspace
fourier_op = FFT(mask=mask, shape=image.shape)
kspace_data = fourier_op.op(image)

# Zero order solution
image_rec0 = np.abs(fourier_op.adj_op(kspace_data))

# Calculate SSIM
base_ssim = ssim(image_rec0, image)
print(base_ssim)

#%%
# POGM optimization
# ------------------
# We now want to refine the zero order solution using an accelerated Proximal Gradient
# Descent algorithm (FISTA or POGM).
# The cost function is set to Proximity Cost + Gradient Cost

# In[4]:


# Setup the operators
linear_op = WaveletN(wavelet_name="sym8", nb_scales=3)

# Manual tweak of the regularisation parameter
regularizer_op = SparseThreshold(Identity(), 2e-3, thresh_type="soft")
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
fourier_op=fourier_op,
linear_op=linear_op,
regularizer_op=regularizer_op,
gradient_formulation='synthesis',
verbose=1,
)
# Start Reconstruction
x_final, costs, metrics = reconstructor.reconstruct(
kspace_data=kspace_data,
optimization_alg='pogm',
num_iterations=100,
cost_op_kwargs={"cost_interval":None},
metric_call_period=1,
metrics = {
"snr":{
"metric": snr,
"mapping": {"x_new":"test"},
"cst_kwargs": {"ref": image},
"early_stopping":False,
},
"ssim":{
"metric": ssim,
"mapping": {"x_new":"test"},
"cst_kwargs": {"ref": image},
"early_stopping": False,
}
}
)

image_rec = np.abs(x_final)
# image_rec.show()
# Calculate SSIM
recon_ssim = ssim(image_rec, image)
recon_snr= snr(image_rec, image)

print('The Reconstruction SSIM is : ' + str(recon_ssim))
print('The Reconstruction SNR is : ' + str(recon_snr))

#%%
# Threshold estimation using SURE
# -------------------------------

_w = None

def static_weight(w, idx):
print(np.unique(w))
return w

# Setup the operators
linear_op = WaveletN(wavelet_name="sym8", nb_scale=3,padding_mode="periodization")
coeffs = linear_op.op(image_rec0)
print(linear_op.coeffs_shape)

# Here we don't manually setup the regularisation weights, but use statistics on the wavelet details coefficients

regularizer_op = AutoWeightedSparseThreshold(
linear_op.coeffs_shape, linear=Identity(),
update_period=0, # the weight is updated only once.
sigma_range="global",
thresh_range="global",
threshold_estimation="sure",
thresh_type="soft",
)
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
fourier_op=fourier_op,
linear_op=linear_op,
regularizer_op=regularizer_op,
gradient_formulation='synthesis',
verbose=1,
)
# Start Reconstruction
x_final, costs, metrics2 = reconstructor.reconstruct(
kspace_data=kspace_data,
optimization_alg='pogm',
num_iterations=100,
metric_call_period=1,
cost_op_kwargs={"cost_interval":None},
metrics = {
"snr":{
"metric": snr,
"mapping": {"x_new":"test"},
"cst_kwargs": {"ref": image},
"early_stopping":False,
},
"ssim":{
"metric": ssim,
"mapping": {"x_new":"test"},
"cst_kwargs": {"ref": image},
"early_stopping": False,
},
"cost_grad":{
"metric": lambda x: reconstructor.gradient_op.cost(linear_op.op(x)),
"mapping": {"x_new":"x"},
"cst_kwargs": {},
"early_stopping": False,
},
"cost_prox":{
"metric": lambda x: reconstructor.prox_op.cost(linear_op.op(x)),
"mapping": {"x_new":"x"},
"cst_kwargs": {},
"early_stopping": False,
}
}
)
image_rec2 = np.abs(x_final)
# image_rec.show()
# Calculate SSIM
recon_ssim2 = ssim(image_rec2, image)
recon_snr2 = snr(image_rec2, image)

print('The Reconstruction SSIM is : ' + str(recon_ssim2))
print('The Reconstruction SNR is : ' + str(recon_snr2))

plt.subplot(121)
plt.plot(metrics["snr"]["time"], metrics["snr"]["values"], label="pogm classic")
plt.plot(metrics2["snr"]["time"], metrics2["snr"]["values"], label="pogm sure global")
plt.ylabel("snr")
plt.xlabel("time")
plt.legend()
plt.subplot(122)
plt.plot(metrics["ssim"]["time"], metrics["ssim"]["values"])
plt.plot(metrics2["ssim"]["time"], metrics2["ssim"]["values"])
plt.ylabel("ssim")
plt.xlabel("time")
plt.figure()
plt.subplot(121)
plt.plot(metrics["snr"]["index"], metrics["snr"]["values"])
plt.plot(metrics2["snr"]["index"], metrics2["snr"]["values"])
plt.ylabel("snr")
plt.subplot(122)
plt.plot(metrics["ssim"]["index"], metrics["ssim"]["values"])
plt.plot(metrics2["ssim"]["index"], metrics2["ssim"]["values"])


#%%
# Qualitative results
# -------------------
#
def my_imshow(ax, img, title):
ax.imshow(img, cmap="gray")
ax.set_title(title)
ax.axis("off")



fig, axs = plt.subplots(2,2)

my_imshow(axs[0,0], image, "Ground Truth")
my_imshow(axs[0,1], abs(image_rec0), f"Zero Order \n SSIM={base_ssim:.4f}")
my_imshow(axs[1,0], abs(image_rec), f"Fista Classic \n SSIM={recon_ssim:.4f}")
my_imshow(axs[1,1], abs(image_rec2), f"Fista Sure \n SSIM={recon_ssim2:.4f}")

fig.tight_layout()
8 changes: 4 additions & 4 deletions mri/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class KspaceGeneratorBase:
Parameters
----------
full_kspace: np.ndarray
full_kspace: numpy.ndarray
The fully sampled kspace, which will be returned incrementally,
use for the Fourier transform.
mask: np.ndarray
mask: numpy.ndarray
A binary mask, giving the sampled location for the kspace
"""

Expand All @@ -30,9 +30,9 @@ def __init__(self, full_kspace: np.ndarray, mask: np.ndarray, max_iter: int = 1)
Parameters
-----------
full_kspace: ndarray
full_kspace: numpy.ndarray
The full kspace data
mask: ndarray
mask: numpy.ndarray
The mask for undersampling the k-space data.
max_iter: int
Maximum number of iterations to be yields.
Expand Down
2 changes: 1 addition & 1 deletion mri/generators/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Column2DKspaceGenerator(KspaceGeneratorBase):
Parameters
----------
full_kspace: ndarray
full_kspace: numpy.ndarray
Complete kspace_data.
mask_cols: array_like
List of the column indices to use for the mask.
Expand Down
3 changes: 1 addition & 2 deletions mri/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
# for details.
##########################################################################

""" This module defines the common operators.
"""
"""This module defines the common operators."""

from .fourier.cartesian import FFT
from .fourier.non_cartesian import NonCartesianFFT
Expand Down
27 changes: 16 additions & 11 deletions mri/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,46 @@
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html for details. #
# #############################################################################

"""
Base Operator.
class OperatorBase(object):
""" Base Operator class. Every linear operator inherits from this class,
Every operator should have an `op` and `adj_op` methods.
"""

class OperatorBase:
"""Base Operator class.
Every linear operator inherits from this class,
to ensure that we have all the functions rightly implemented
as required by Modopt
"""

def op(self, data):
""" This method calculates operator transform.
"""Compute operator transform.
Parameters
----------
data: np.ndarray
data: numpy.ndarray
input as array.
Returns
-------
result: np.ndarray
result: numpy.ndarray
operator transform of the input.
"""

raise NotImplementedError("'op' is an abstract method.")

def adj_op(self, x):
""" This method calculates adjoint operator transform.
def adj_op(self, coeffs):
"""Compute adjoint operator transform.
Parameters
----------
x: np.ndarray
x: numpy.ndarray
input data array.
Returns
-------
results: np.ndarray
results: numpy.ndarray
adjoint operator transform.
"""

raise NotImplementedError("'adj_op' is an abstract method.")
Loading

0 comments on commit 6fc7d95

Please sign in to comment.