diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index e6843a95a..ed3ea469f 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -137,6 +137,8 @@ jobs: pip install torchkbnufft elif [[ ${{ matrix.backend }} == "tensorflow" ]]; then pip install tensorflow-mri==0.21.0 tensorflow-probability==0.17.0 tensorflow-io==0.27.0 matplotlib==3.7 + elif [[ ${{ matrix.backend }} == "cufinufft" ]]; then + pip install "cufinufft<2.3" else pip install ${{ matrix.backend }} fi @@ -213,7 +215,7 @@ jobs: export PATH=/usr/local/cuda-12.1/bin/:${PATH} export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64/:${LD_LIBRARY_PATH} pip install cupy-cuda12x torch - python -m pip install gpuNUFFT cufinufft sigpy scikit-image + python -m pip install gpuNUFFT "cufinufft<2.3" sigpy scikit-image - name: Run examples shell: bash @@ -324,7 +326,7 @@ jobs: export PATH=/usr/local/cuda-12.1/bin/:${PATH} export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64/:${LD_LIBRARY_PATH} pip install cupy-cuda12x torch - python -m pip install gpuNUFFT cufinufft + python -m pip install gpuNUFFT "cufinufft<2.3" - name: Build API documentation run: | diff --git a/docs/sphinx_add_colab_link.py b/docs/sphinx_add_colab_link.py index 26f6f5642..c247dd750 100644 --- a/docs/sphinx_add_colab_link.py +++ b/docs/sphinx_add_colab_link.py @@ -73,6 +73,7 @@ def notebook_modifier(self, notebook_path, commands): idx = self.find_index_of_colab_link(notebook) code_lines = ["# Install libraries"] code_lines.append(commands) + code_lines.append("pip install brainweb-dl # Required for data") dummy_notebook_content = {"cells": []} add_code_cell( dummy_notebook_content, diff --git a/examples/example_offresonance.py b/examples/example_offresonance.py new file mode 100644 index 000000000..a16c547cc --- /dev/null +++ b/examples/example_offresonance.py @@ -0,0 +1,110 @@ +""" +====================== +Off-resonance Corrected NUFFT Operator +====================== + +Example of Off-resonance Corrected NUFFT trajectory operator. + +This examples show how to use the Off-resonance Corrected NUFFT operator to acquire +and reconstruct data in presence of field inhomogeneities. +Here a spiral trajectory is used as a demonstration. + +""" + +import matplotlib.pyplot as plt +import numpy as np + +from mrinufft import display_2D_trajectory + +plt.rcParams["image.cmap"] = "gray" + +# %% +# Data Generation +# =============== +# For realistic 2D image we will use a slice from the brainweb dataset. +# installable using ``pip install brainweb-dl`` + +from brainweb_dl import get_mri + +mri_data = get_mri(0, "T1") +mri_data = mri_data[::-1, ...][90] +plt.imshow(mri_data), plt.axis("off"), plt.title("ground truth") + +# %% +# Masking +# =============== +# Here, we generate a binary mask to exclude the background. +# We perform a simple binary threshold; in real-world application, +# it is advised to use other tools (e.g., FSL-BET). + +brain_mask = mri_data > 0.1 * mri_data.max() +plt.imshow(brain_mask), plt.axis("off"), plt.title("brain mask") + +# %% +# Field Generation +# =============== +# Here, we generate a radial B0 field with the same shape of +# the input Shepp-Logan phantom + +from mrinufft.extras import make_b0map + +# generate field +b0map, _ = make_b0map(mri_data.shape, b0range=(-200, 200), mask=brain_mask) +plt.imshow(brain_mask * b0map, cmap="bwr", vmin=-200, vmax=200), plt.axis( + "off" +), plt.colorbar(), plt.title("B0 map [Hz]") + +# %% +# Generate a Spiral trajectory +# ---------------------------- + +from mrinufft import initialize_2D_spiral +from mrinufft.density import voronoi +from mrinufft.trajectories.utils import DEFAULT_RASTER_TIME + +samples = initialize_2D_spiral(Nc=48, Ns=600, nb_revolutions=10) +t_read = np.arange(samples.shape[1]) * DEFAULT_RASTER_TIME * 1e-3 +t_read = np.repeat(t_read[None, ...], samples.shape[0], axis=0) +density = voronoi(samples) + +display_2D_trajectory(samples) + +# %% +# Setup the Operator +# ================== + +from mrinufft import get_operator +from mrinufft.operators.off_resonance import MRIFourierCorrected + +# Generate standard NUFFT operator +nufft = get_operator("finufft")( + samples=samples, + shape=mri_data.shape, + density=density, +) + +# Generate Fourier Corrected operator +mfi_nufft = MRIFourierCorrected( + nufft, b0_map=b0map, readout_time=t_read, mask=brain_mask +) + +# Generate K-Space +kspace = mfi_nufft.op(mri_data) + +# Reconstruct without field correction +mri_data_adj = nufft.adj_op(kspace) +mri_data_adj = np.squeeze(abs(mri_data_adj)) + +# Reconstruct with field correction +mri_data_adj_mfi = mfi_nufft.adj_op(kspace) +mri_data_adj_mfi = np.squeeze(abs(mri_data_adj_mfi)) + +fig2, ax2 = plt.subplots(1, 2) +ax2[0].imshow(mri_data_adj), ax2[0].axis("off"), ax2[0].set_title("w/o correction") +ax2[1].imshow(mri_data_adj_mfi), ax2[1].axis("off"), ax2[1].set_title("with correction") + +plt.show() + +# %% +# The blurring is significantly reduced using the Off-resonance Corrected +# operator (right) diff --git a/pyproject.toml b/pyproject.toml index 143284e71..bfe9bcd7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dynamic = ["version"] gpunufft = ["gpuNUFFT>=0.9.0", "cupy-cuda12x"] torchkbnufft = ["torchkbnufft", "cupy-cuda12x"] -cufinufft = ["cufinufft", "cupy-cuda12x"] +cufinufft = ["cufinufft<2.3", "cupy-cuda12x"] finufft = ["finufft"] pynfft = ["pynfft2>=1.4.3", "numpy>=2.0.0"] pynufft = ["pynufft"] diff --git a/src/mrinufft/__init__.py b/src/mrinufft/__init__.py index 99ddf4a80..69bae1a26 100644 --- a/src/mrinufft/__init__.py +++ b/src/mrinufft/__init__.py @@ -10,6 +10,7 @@ get_operator, check_backend, list_backends, + get_interpolators_from_fieldmap, ) from .trajectories import ( @@ -56,6 +57,7 @@ "get_operator", "check_backend", "list_backends", + "get_interpolators_from_fieldmap", "initialize_2D_radial", "initialize_2D_spiral", "initialize_2D_fibonacci_spiral", diff --git a/src/mrinufft/extras/__init__.py b/src/mrinufft/extras/__init__.py index 6ba1d4a4d..3c08cb889 100644 --- a/src/mrinufft/extras/__init__.py +++ b/src/mrinufft/extras/__init__.py @@ -1,10 +1,13 @@ """Sensitivity map estimation methods.""" +from .data import make_b0map, make_t2smap from .smaps import low_frequency from .utils import get_smaps __all__ = [ + "make_b0map", + "make_t2smap", "low_frequency", "get_smaps", ] diff --git a/src/mrinufft/extras/data.py b/src/mrinufft/extras/data.py new file mode 100644 index 000000000..5dfcf281e --- /dev/null +++ b/src/mrinufft/extras/data.py @@ -0,0 +1,106 @@ +"""Field map generator module.""" + +import numpy as np + + +def make_b0map(shape, b0range=(-300, 300), mask=None): + """ + Make radial B0 map. + + Parameters + ---------- + shape : tuple[int] + Matrix size. Only supports isotropic matrices. + b0range : tuple[float], optional + Frequency shift range in [Hz]. The default is (-300, 300). + mask : np.ndarray + Spatial support of the objec. If not provided, + build a radial mask with radius = 0.3 * shape + + Returns + ------- + np.ndarray + B0 map of shape (*shape) in [Hz], + with values included in (*b0range). + mask : np.ndarray, optional + Spatial support binary mask. + + """ + assert np.unique(shape).size, ValueError("Only isotropic matriex are supported.") + ndim = len(shape) + if ndim == 2: + radial_mask, fieldmap = _make_disk(shape) + elif ndim == 3: + radial_mask, fieldmap = _make_sphere(shape) + if mask is None: + mask = radial_mask + + # build map + fieldmap *= mask + fieldmap = (b0range[1] - b0range[0]) * fieldmap / fieldmap.max() + b0range[0] # Hz + fieldmap *= mask + + # remove nan + fieldmap = np.nan_to_num(fieldmap, neginf=0.0, posinf=0.0) + + return fieldmap.astype(np.float32), mask + + +def make_t2smap(shape, t2svalue=15.0, mask=None): + """ + Make homogeneous T2* map. + + Parameters + ---------- + shape : tuple[int] + Matrix size. + t2svalue : float, optional + Object T2* in [ms]. The default is 15.0. + mask : np.ndarray + Spatial support of the objec. If not provided, + build a radial mask with radius = 0.3 * shape + + Returns + ------- + np.ndarray + T2* map of shape (*shape) in [ms]. + mask : np.ndarray, optional + Spatial support binary mask. + + """ + assert np.unique(shape).size, ValueError("Only isotropic matriex are supported.") + ndim = len(shape) + if ndim == 2: + radial_mask, fieldmap = _make_disk(shape) + elif ndim == 3: + radial_mask, fieldmap = _make_sphere(shape) + if mask is None: + mask = radial_mask + + # build map + fieldmap = t2svalue * mask # ms + + # remove nan + fieldmap = np.nan_to_num(fieldmap, neginf=0.0, posinf=0.0) + + return fieldmap.astype(np.float32), mask + + +def _make_disk(shape, frac_radius=0.3): + """Make circular binary mask.""" + ny, nx = shape + yy, xx = np.mgrid[:ny, :nx] + yy, xx = yy - ny // 2, xx - nx // 2 + yy, xx = yy / ny, xx / nx + rr = (xx**2 + yy**2) ** 0.5 + return rr < frac_radius, rr + + +def _make_sphere(shape, frac_radius=0.3): + """Make spherical binary mask.""" + nz, ny, nx = shape + zz, yy, xx = np.mgrid[:nz, :ny, :nx] + zz, yy, xx = zz - nz // 2, yy - ny // 2, xx - nx // 2 + zz, yy, xx = zz / nz, yy / ny, xx / nx + rr = (xx**2 + yy**2 + zz**2) ** 0.5 + return rr < frac_radius, rr diff --git a/src/mrinufft/operators/__init__.py b/src/mrinufft/operators/__init__.py index 6cb0bdabb..c642d47d8 100644 --- a/src/mrinufft/operators/__init__.py +++ b/src/mrinufft/operators/__init__.py @@ -10,7 +10,7 @@ list_backends, check_backend, ) -from .off_resonnance import MRIFourierCorrected +from .off_resonance import MRIFourierCorrected, get_interpolators_from_fieldmap from .stacked import MRIStackedNUFFT # @@ -28,4 +28,5 @@ "check_backend", "get_operator", "list_backends", + "get_interpolators_from_fieldmap", ] diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 0c86e8bbb..1fe79e4a2 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -352,9 +352,9 @@ def data_consistency(self, image, obs_data): """ return self.adj_op(self.op(image) - obs_data) - def with_off_resonnance_correction(self, B, C, indices): + def with_off_resonance_correction(self, B, C, indices): """Return a new operator with Off Resonnance Correction.""" - from ..off_resonnance import MRIFourierCorrected + from ..off_resonance import MRIFourierCorrected return MRIFourierCorrected(self, B, C, indices) diff --git a/src/mrinufft/operators/off_resonance.py b/src/mrinufft/operators/off_resonance.py new file mode 100644 index 000000000..61c6a6e9e --- /dev/null +++ b/src/mrinufft/operators/off_resonance.py @@ -0,0 +1,428 @@ +"""Off Resonance correction Operator wrapper. + +Based on the implementation of Guillaume Daval-Frérot in pysap-mri: +https://github.com/CEA-COSMIC/pysap-mri/blob/master/mri/operators/fourier/orc_wrapper.py +""" + +import math +import numpy as np + +from .._utils import get_array_module + +from .base import FourierOperatorBase, CUPY_AVAILABLE, AUTOGRAD_AVAILABLE +from .interfaces.utils import is_cuda_array + +if CUPY_AVAILABLE: + import cupy as cp + +if AUTOGRAD_AVAILABLE: + import torch + + +def get_interpolators_from_fieldmap( + b0_map, readout_time, n_time_segments=6, n_bins=(40, 10), mask=None, r2star_map=None +): + r"""Approximate ``exp(-2j*pi*fieldmap*readout_time) ≈ Σ B_n(t)C_n(r)``. + + Here, B_n(t) are n_time_segments temporal coefficients and C_n(r) + are n_time_segments temporal spatial coefficients. + + The matrix B has shape ``(n_time_segments, len(readout_time))`` + and C has shape ``(n_time_segments, *b0_map.shape)``. + + From Sigpy: https://github.com/mikgroup/sigpy + and MIRT (mri_exp_approx.m): https://web.eecs.umich.edu/~fessler/code/ + + Parameters + ---------- + b0_map : np.ndarray + Static field inhomogeneities map. + ``b0_map`` and ``readout_time`` should have reciprocal units. + Also supports Cupy arrays and Torch tensors. + readout_time : np.ndarray + Readout time in ``[s]`` of shape ``(n_shots, n_pts)`` or ``(n_shots * n_pts,)``. + Also supports Cupy arrays and Torch tensors. + n_time_segments : int, optional + Number of time segments. The default is ``6``. + n_bins : int | Sequence[int] optional + Number of histogram bins to use for ``(B0, T2*)``. The default is ``(40, 10)`` + If it is a scalar, assume ``n_bins = (n_bins, 10)``. + For real fieldmap (B0 only), ``n_bins[1]`` is ignored. + mask : np.ndarray, optional + Boolean mask of the region of interest + (e.g., corresponding to the imaged object). + This is used to exclude the background fieldmap values + from histogram computation. Must have same shape as ``b0_map``. + The default is ``None`` (use the whole map). + Also supports Cupy arrays and Torch tensors. + r2star_map : np.ndarray, optional + Effective transverse relaxation map (R2*). + ``r2star_map`` and ``readout_time`` should have reciprocal units. + Must have same shape as ``b0_map``. + The default is ``None`` (purely imaginary field). + Also supports Cupy arrays and Torch tensors. + + Notes + ----- + The total field map used to calculate the field coefficients is + ``field_map = R2*_map + 1j * B0_map``. If R2* is not provided, + the field is purely immaginary: ``field_map = 1j * B0_map``. + + Returns + ------- + B : np.ndarray + Temporal interpolator of shape ``(n_time_segments, len(t))``. + Array module is the same as input field_map. + tl : np.ndarray + Time segment centers of shape ``(n_time_segments,)``. + Array module is the same as input field_map. + + """ + # default + if not isinstance(n_bins, (list, tuple)): + n_bins = (n_bins, 10) + n_bins = list(n_bins) + + # get backend and device + xp = get_array_module(b0_map) + + # cast arrays to fieldmap backend + is_torch = xp.__name__ == "torch" + + if is_cuda_array(b0_map): + assert CUPY_AVAILABLE, "GPU computation requires Cupy!" + xp = cp + b0_map = _to_cupy(b0_map) + readout_time = _to_cupy(readout_time) + mask = _to_cupy(mask) + r2star_map = _to_cupy(r2star_map) + else: + xp = np + b0_map = _to_numpy(b0_map) + readout_time = _to_numpy(readout_time) + mask = _to_numpy(mask) + r2star_map = _to_numpy(r2star_map) + + readout_time = xp.asarray(readout_time, dtype=xp.float32).ravel() + if mask is None: + mask = xp.ones_like(b0_map, dtype=bool) + else: + mask = xp.asarray(mask, dtype=bool) + + # Hz to radians / s + field_map = _get_complex_fieldmap(b0_map, r2star_map) + + # enforce precision + field_map = xp.asarray(field_map, dtype=xp.complex64) + + # create histograms + z = field_map[mask].ravel() + + if r2star_map is not None: + z = xp.stack((z.imag, z.real), axis=1) + hk, ze = xp.histogramdd(z, bins=n_bins) + ze = list(ze) + + # get bin centers + zc = [e[1:] - (e[1] - e[0]) / 2 for e in ze] + + # complexify + zk = _outer_sum(1j * zc[0], zc[1]) # [K1 K2] + zk = zk.T + hk = hk.T + else: + hk, ze = xp.histogram(z.imag, bins=n_bins[0]) + + # get bin centers + zc = ze[1:] - (ze[1] - ze[0]) / 2 + + # complexify + zk = 1j * zc # [K 1] + + # flatten histogram values and centers + hk = hk.ravel() + zk = zk.ravel() + + # generate time for each segment + tl = xp.linspace( + readout_time.min(), readout_time.max(), n_time_segments, dtype=xp.float32 + ) # time seg centers in [s] + + # prepare for basis calculation + ch = xp.exp(-tl[:, None, ...] @ zk[None, ...]) + w = xp.diag(hk**0.5) + p = xp.linalg.pinv(w @ ch.T) @ w + + # actual temporal basis calculation + B = p @ xp.exp(-zk[:, None, ...] * readout_time[None, ...]) + B = B.astype(xp.complex64) + + # back to torch if required + if is_torch: + B = _to_torch(B) + tl = _to_torch(tl) + + return B, tl + + +def _outer_sum(xx, yy): + xx = xx[:, None, ...] # add a singleton dimension at axis 1 + yy = yy[None, ...] # add a singleton dimension at axis 0 + ss = xx + yy # compute the outer sum + return ss + + +# TODO: /* refactor with_* decorators +def _to_numpy(input): + if input is None: + return input + xp = get_array_module(input) + + if xp.__name__ == "torch": + return input.numpy(force=True) + elif xp.__name__ == "cupy": + return input.get() + else: + return input + + +def _to_cupy(input): + if input is None: + return input + return cp.asarray(input) + + +def _to_torch(input): + xp = get_array_module(input) + + if xp.__name__ == "numpy": + return torch.from_numpy(input) + elif xp.__name__ == "cupy": + return torch.from_dlpack(input) + else: + return input + + +# */ + + +class MRIFourierCorrected(FourierOperatorBase): + """Fourier Operator with B0 Inhomogeneities compensation. + + This is a wrapper around the Fourier Operator to compensate for the + B0 inhomogeneities in the k-space. + + Parameters + ---------- + b0_map : np.ndarray + Static field inhomogeneities map. + ``b0_map`` and ``readout_time`` should have reciprocal units. + Also supports Cupy arrays and Torch tensors. + readout_time : np.ndarray + Readout time in ``[s]`` of shape ``(n_shots, n_pts)`` or ``(n_shots * n_pts,)``. + Also supports Cupy arrays and Torch tensors. + n_time_segments : int, optional + Number of time segments. The default is ``6``. + n_bins : int | Sequence[int] optional + Number of histogram bins to use for ``(B0, T2*)``. The default is ``(40, 10)`` + If it is a scalar, assume ``n_bins = (n_bins, 10)``. + For real fieldmap (B0 only), ``n_bins[1]`` is ignored. + mask : np.ndarray, optional + Boolean mask of the region of interest + (e.g., corresponding to the imaged object). + This is used to exclude the background fieldmap values + from histogram computation. + The default is ``None`` (use the whole map). + Also supports Cupy arrays and Torch tensors. + B : np.ndarray, optional + Temporal interpolator of shape ``(n_time_segments, len(readout_time))``. + tl : np.ndarray, optional + Time segment centers of shape ``(n_time_segments,)``. + Also supports Cupy arrays and Torch tensors. + r2star_map : np.ndarray, optional + Effective transverse relaxation map (R2*). + ``r2star_map`` and ``readout_time`` should have reciprocal units. + Must have same shape as ``b0_map``. + The default is ``None`` (purely imaginary field). + Also supports Cupy arrays and Torch tensors. + backend: str, optional + The backend to use for computations. Either 'cpu', 'gpu' or 'torch'. + The default is `cpu`. + + Notes + ----- + The total field map used to calculate the field coefficients is + ``field_map = R2*_map + 1j * B0_map``. If R2* is not provided, + the field is purely immaginary: ``field_map = 1j * B0_map``. + + """ + + def __init__( + self, + fourier_op, + b0_map=None, + readout_time=None, + n_time_segments=6, + n_bins=(40, 10), + mask=None, + r2star_map=None, + B=None, + tl=None, + backend="cpu", + ): + if backend == "gpu" and not CUPY_AVAILABLE: + raise RuntimeError("Cupy is required for gpu computations.") + if backend == "torch": + self.xp = torch + if backend == "gpu": + self.xp = cp + elif backend == "cpu": + self.xp = np + else: + raise ValueError("Unsupported backend.") + self._fourier_op = fourier_op + + self.n_coils = fourier_op.n_coils + self.shape = fourier_op.shape + self.smaps = fourier_op.smaps + self.autograd_available = fourier_op.autograd_available + + if B is not None and tl is not None: + self.B = self.xp.asarray(B) + self.tl = self.xp.asarray(tl) + else: + b0_map = self.xp.asarray(b0_map) + self.B, self.tl = get_interpolators_from_fieldmap( + b0_map, + readout_time, + n_time_segments, + n_bins, + mask, + r2star_map, + ) + if self.B is None or self.tl is None: + raise ValueError("Please either provide fieldmap and t or B and tl") + self.n_interpolators = self.B.shape[0] + + # create spatial interpolator + field_map = _get_complex_fieldmap(b0_map, r2star_map) + if is_cuda_array(b0_map): + self.C = None + self.field_map = field_map + else: + self.C = _get_spatial_coefficients(field_map, self.tl) + self.field_map = None + + def op(self, data, *args): + """Compute Forward Operation with off-resonance effect. + + Parameters + ---------- + x: numpy.ndarray + N-D input image. + Also supports Cupy arrays and Torch tensors. + + Returns + ------- + numpy.ndarray + Masked distorted N-D k-space. + Array module is the same as input data. + + """ + y = 0.0 + data_d = self.xp.asarray(data) + if self.C is not None: + for idx in range(self.n_interpolators): + y += self.B[idx] * self._fourier_op.op(self.C[idx] * data_d, *args) + else: + for idx in range(self.n_interpolators): + C = self.xp.exp(-self.field_map * self.tl[idx].item()) + y += self.B[idx] * self._fourier_op.op(C * data_d, *args) + + return y + + def adj_op(self, coeffs, *args): + """ + Compute Adjoint Operation with off-resonance effect. + + Parameters + ---------- + x: numpy.ndarray + Masked distorted N-D k-space. + Also supports Cupy arrays and Torch tensors. + + + Returns + ------- + numpy.ndarray + Inverse Fourier transform of the distorted input k-space. + Array module is the same as input coeffs. + + """ + y = 0.0 + coeffs_d = self.xp.array(coeffs) + if self.C is not None: + for idx in range(self.n_interpolators): + y += self.xp.conj(self.C[idx]) * self._fourier_op.adj_op( + self.xp.conj(self.B[idx]) * coeffs_d, *args + ) + else: + for idx in range(self.n_interpolators): + C = self.xp.exp(-self.field_map * self.tl[idx].item()) + y += self.xp.conj(C) * self._fourier_op.adj_op( + self.xp.conj(self.B[idx]) * coeffs_d, *args + ) + + return y + + @staticmethod + def get_spatial_coefficients(field_map, tl): + """Compute spatial coefficients for field approximation. + + Parameters + ---------- + field_map : np.ndarray + Total field map used to calculate the field coefficients is + ``field_map = R2*_map + 1j * B0_map``. + Also supports Cupy arrays and Torch tensors. + tl : np.ndarray + Time segment centers of shape ``(n_time_segments,)``. + Also supports Cupy arrays and Torch tensors. + + Returns + ------- + C : np.ndarray + Off-resonance phase map at each time segment center of shape + ``(n_time_segments, *field_map.shape)``. + Array module is the same as input field_map. + + """ + return _get_spatial_coefficients(field_map, tl) + + +def _get_complex_fieldmap(b0_map, r2star_map=None): + xp = get_array_module(b0_map) + + if r2star_map is not None: + r2star_map = xp.asarray(r2star_map, dtype=xp.float32) + field_map = 2 * math.pi * (r2star_map + 1j * b0_map) + else: + field_map = 2 * math.pi * 1j * b0_map + + return field_map + + +def _get_spatial_coefficients(field_map, tl): + xp = get_array_module(field_map) + + # get spatial coeffs + C = xp.exp(-tl * field_map[..., None]) + C = C[None, ...].swapaxes(0, -1)[ + ..., 0 + ] # (..., n_time_segments) -> (n_time_segments, ...) + C = xp.asarray(C, dtype=xp.complex64) + + # clean-up of spatial coeffs + C = xp.nan_to_num(C, nan=0.0, posinf=0.0, neginf=0.0) + + return C diff --git a/src/mrinufft/operators/off_resonnance.py b/src/mrinufft/operators/off_resonnance.py deleted file mode 100644 index 0c6f452cf..000000000 --- a/src/mrinufft/operators/off_resonnance.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Off Resonance correction Operator wrapper. - -Based on the implementation of Guillaume Daval-Frérot in pysap-mri: -https://github.com/CEA-COSMIC/pysap-mri/blob/master/mri/operators/fourier/orc_wrapper.py -""" - -import numpy as np - -from .base import FourierOperatorBase -from .interfaces.utils import is_cuda_array - -CUPY_AVAILABLE = True -try: - import cupy as cp -except ImportError: - CUPY_AVAILABLE = False - - -class MRIFourierCorrected(FourierOperatorBase): - """Fourier Operator with B0 Inhomogeneities compensation. - - This is a wrapper around the Fourier Operator to compensate for the - B0 inhomogeneities in the k-space. - - Parameters - ---------- - fourier_op: object of class FourierBase - the fourier operator to wrap - B: numpy.ndarray - C: numpy.ndarray - indices: numpy.ndarray - backend: str, default 'cpu' - the backend to use for computations. Either 'cpu' or 'gpu'. - """ - - def __init__(self, fourier_op, B, C, indices, backend="cpu"): - if backend == "gpu" and not CUPY_AVAILABLE: - raise RuntimeError("Cupy is required for gpu computations.") - if backend == "gpu": - self.xp = cp - elif backend == "cpu": - self.xp = np - else: - raise ValueError("Unsupported backend.") - self._fourier_op = fourier_op - - if not fourier_op.uses_sense: - raise ValueError("please use smaps.") - - self.n_samples = fourier_op.n_samples - self.n_coils = fourier_op.n_coils - self.shape = fourier_op.shape - self.smaps = fourier_op.smaps - self.n_interpolators = len(C) - self.B = self.xp.array(B) - self.B = self.xp.tile(self.B, (self._fourier_op.n_samples // len(B), 1)) - self.C = self.xp.array(C) - self.indices = indices - - def op(self, data, *args): - """Compute Forward Operation with off-resonnances effect. - - Parameters - ---------- - x: numpy.ndarray or cupy.ndarray - N-D input image - - Returns - ------- - numpy.ndarray or cupy.ndarray - masked distorded N-D k-space - """ - y = self.xp.zeros((self.n_coils, self.n_samples), dtype=np.complex64) - data_d = self.xp.asarray(data) - for idx in range(self.n_interpolators): - y += self.B[..., idx] * self._fourier_op.op( - self.C[idx, self.indices] * data_d, *args - ) - if self.xp.__name__ == "cupy" and is_cuda_array(data): - return y - return y.get() - - def adj_op(self, coeffs, *args): - """ - Compute Adjoint Operation with off-resonnance effect. - - Parameters - ---------- - x: numpy.ndarray or cupy.ndarray - masked distorded N-D k-space - - Returns - ------- - inverse Fourier transform of the distorded input k-space. - """ - y = self.xp.zeros(self.shape, dtype=np.complex64) - coeffs_d = self.xp.array(coeffs) - for idx in range(self.n_interpolators): - y += cp.conj(self.C[idx, self.indices]) * self._fourier_op.adj_op( - cp.conj(self.B[..., idx]) * coeffs_d, *args - ) - if self.xp.__name__ == "cupy" and is_cuda_array(coeffs): - return y - return y.get() - - def get_grad(self, image_data, obs_data): - """Compute the data consistency error. - - Parameters - ---------- - image_data: numpy.ndarray or cupy.ndarray - N-D input image - obs_data: numpy.ndarray or cupy.ndarray - N-D observed k-space - - Returns - ------- - numpy.ndarray or cupy.ndarray - data consistency error in image space. - """ - return self.adj_op(self.op(image_data) - obs_data) diff --git a/tests/case_fieldmaps.py b/tests/case_fieldmaps.py new file mode 100644 index 000000000..cd36dff6f --- /dev/null +++ b/tests/case_fieldmaps.py @@ -0,0 +1,57 @@ +"""Fieldmap cases we want to test.""" + +from mrinufft.extras import make_b0map, make_t2smap + + +class CasesB0maps: + """B0 field maps cases we want to test. + + Each case return a field map and the binary spatial support of the object. + """ + + def case_real2D(self, N=64, b0range=(-300, 300)): + """Create a real (B0 only) 2D field map.""" + return make_b0map(2 * [N]) + + def case_real3D(self, N=64, b0range=(-300, 300)): + """Create a real (B0 only) 3D field map.""" + return make_b0map(3 * [N]) + + +class CasesZmaps: + """Complex zmap field maps cases we want to test. + + Each case return a field map and the binary spatial support of the object. + """ + + def case_complex2D(self, N=64, b0range=(-300, 300), t2svalue=15.0): + """Create a complex (R2* + 1j * B0) 2D field map.""" + # Generate real and imaginary parts + t2smap, _ = make_t2smap(2 * [N]) + b0map, mask = make_b0map(2 * [N]) + + # Convert T2* map to R2* map + t2smap = t2smap * 1e-3 # ms -> s + r2smap = 1.0 / (t2smap + 1e-9) # Hz + r2smap = mask * r2smap + + # Calculate complex fieldmap (Zmap) + zmap = r2smap + 1j * b0map + + return zmap, mask + + def case_complex3D(self, N=64, b0range=(-300, 300), t2svalue=15.0): + """Create a complex (R2* + 1j * B0) 3D field map.""" + # Generate real and imaginary parts + t2smap, _ = make_t2smap(3 * [N]) + b0map, mask = make_b0map(3 * [N]) + + # Convert T2* map to R2* map + t2smap = t2smap * 1e-3 # ms -> s + r2smap = 1.0 / (t2smap + 1e-9) # Hz + r2smap = mask * r2smap + + # Calculate complex fieldmap (Zmap) + zmap = r2smap + 1j * b0map + + return zmap, mask diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 6afe23f0e..555cd224f 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -1,6 +1,6 @@ """Helper functions for testing the operators.""" -from .asserts import assert_almost_allclose, assert_correlate +from .asserts import assert_almost_allclose, assert_correlate, assert_allclose from .factories import ( kspace_from_op, image_from_op, @@ -14,7 +14,7 @@ __all__ = [ "assert_almost_allclose", "assert_correlate", - "kspace_from_op", + "assert_allclose" "kspace_from_op", "image_from_op", "to_interface", "from_interface", diff --git a/tests/helpers/asserts.py b/tests/helpers/asserts.py index ba3ed07f9..c74a015f5 100644 --- a/tests/helpers/asserts.py +++ b/tests/helpers/asserts.py @@ -4,6 +4,8 @@ import numpy.testing as npt import scipy as sp +from .factories import from_interface + def assert_almost_allclose(a, b, rtol, atol, mismatch, equal_nan=False): """Assert allclose with a tolerance on the number of mismatched elements. @@ -64,3 +66,10 @@ def assert_correlate(a, b, slope=1.0, slope_err=1e-3, r_value_err=1e-3): f"intercept={intercept}, stderr={stderr}, " f"intercept_stderr={intercept_stderr}" ) + + +def assert_allclose(actual, expected, atol, rtol, interface): + """Backend agnostic assertion using from_interface helper.""" + actual_np = from_interface(actual, interface) + expected_np = from_interface(expected, interface) + npt.assert_allclose(actual_np, expected_np, atol=atol, rtol=rtol) diff --git a/tests/operators/test_offres_exp_approx.py b/tests/operators/test_offres_exp_approx.py new file mode 100644 index 000000000..dfcc03f45 --- /dev/null +++ b/tests/operators/test_offres_exp_approx.py @@ -0,0 +1,108 @@ +"""Test off-resonance spatial coefficient and temporal interpolator estimation.""" + +import math + +import numpy as np + +import pytest +from pytest_cases import parametrize_with_cases + + +import mrinufft +from mrinufft._utils import get_array_module +from mrinufft.operators.base import CUPY_AVAILABLE +from mrinufft.operators.off_resonance import MRIFourierCorrected + + +from helpers import to_interface, assert_allclose +from helpers.factories import _param_array_interface +from case_fieldmaps import CasesB0maps, CasesZmaps + + +def calculate_true_offresonance_term(fieldmap, t, array_interface): + """Calculate non-approximate off-resonance modulation term.""" + fieldmap = to_interface(fieldmap, array_interface) + t = to_interface(t, array_interface) + + xp = get_array_module(fieldmap) + arg = t * fieldmap[..., None] + arg = arg[None, ...].swapaxes(0, -1)[..., 0] + return xp.exp(-arg) + + +def calculate_approx_offresonance_term(B, C): + """Calculate approximate off-resonance modulation term.""" + field_term = 0.0 + for n in range(B.shape[0]): + tmp = B[n] * C[n][..., None] + tmp = tmp[None, ...].swapaxes(0, -1)[..., 0] + field_term += tmp + return field_term + + +@_param_array_interface +@parametrize_with_cases("b0map, mask", cases=CasesB0maps) +def test_b0map_coeff(b0map, mask, array_interface): + """Test exponential approximation for B0 field only.""" + if array_interface == "torch-gpu" and not CUPY_AVAILABLE: + pytest.skip("GPU computations requires cupy") + + # Generate readout times + tread = np.linspace(0.0, 5e-3, 501, dtype=np.float32) + + # Generate coefficients + B, tl = mrinufft.get_interpolators_from_fieldmap( + to_interface(b0map, array_interface), tread, mask=mask, n_time_segments=100 + ) + + # Calculate spatial coefficients + C = MRIFourierCorrected.get_spatial_coefficients( + to_interface(2 * math.pi * 1j * b0map, array_interface), tl + ) + + # Assert properties + assert B.shape == (100, 501) + assert C.shape == (100, *b0map.shape) + + # Correct approximation + expected = calculate_true_offresonance_term( + 0 + 2 * math.pi * 1j * b0map, tread, array_interface + ) + actual = calculate_approx_offresonance_term(B, C) + assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface) + + +@_param_array_interface +@parametrize_with_cases("zmap, mask", cases=CasesZmaps) +def test_zmap_coeff(zmap, mask, array_interface): + """Test exponential approximation for complex Z = R2* + 1j *B0 field.""" + if array_interface == "torch-gpu" and CUPY_AVAILABLE is False: + pytest.skip("GPU computations requires cupy") + + # Generate readout times + tread = np.linspace(0.0, 5e-3, 501, dtype=np.float32) + + # Generate coefficients + B, tl = mrinufft.get_interpolators_from_fieldmap( + to_interface(zmap.imag, array_interface), + tread, + mask=mask, + r2star_map=to_interface(zmap.real, array_interface), + n_time_segments=100, + ) + + # Calculate spatial coefficients + C = MRIFourierCorrected.get_spatial_coefficients( + to_interface(2 * math.pi * zmap, array_interface), tl + ) + + # Assert properties + assert B.shape == (100, 501) + assert C.shape == (100, *zmap.shape) + + # Correct approximation + expected = calculate_true_offresonance_term( + 2 * math.pi * zmap, tread, array_interface + ) + actual = calculate_approx_offresonance_term(B, C) + assert_allclose(actual, expected, atol=1e-3, rtol=1e-3, interface=array_interface)