Skip to content

Commit

Permalink
enhance: pure pytorch based grog; todo: nonlinear reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
mcencini committed Apr 3, 2024
1 parent e18af1d commit 8f9863b
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/deepmr/fft/_interp/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def apply_interpolation(
)

# do actual interpolation
if device == "cpu":
if device == "cpu" or device == torch.device("cpu"):
do_interpolation[ndim - 1](data_out, data_in, value, index, basis_adjoint)
else:
do_interpolation_cuda[ndim - 1](
Expand Down
2 changes: 1 addition & 1 deletion src/deepmr/fft/_interp/toeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def apply_toeplitz(
).contiguous() # (nvoxels, nbatches, ncontrasts)

# actual interpolation
if toeplitz_kernel.device == "cpu":
if toeplitz_kernel.device == "cpu" or toeplitz_kernel.device == torch.device("cpu"):
do_selfadjoint_interpolation(data_out, data_in, toeplitz_kernel.value)
else:
do_selfadjoint_interpolation_cuda(
Expand Down
2 changes: 1 addition & 1 deletion src/deepmr/fft/_sparse/dense2sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def apply_sampling(data_in, mask, basis_adjoint=None, device=None, threadsperblo
)

# do actual sampling
if device == "cpu":
if device == "cpu" or device == torch.device("cpu"):
do_sampling[ndim - 1](data_out, data_in, index, basis_adjoint)
else:
do_sampling_cuda[ndim - 1](
Expand Down
2 changes: 1 addition & 1 deletion src/deepmr/fft/_sparse/sparse2dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def apply_zerofill(data_in, mask, basis=None, device=None, threadsperblock=128):
)

# do actual zerofill
if device == "cpu":
if device == "cpu" or device == torch.device("cpu"):
do_zerofill[ndim - 1](data_out, data_in, index, basis)
else:
do_zerofill_cuda[ndim - 1](data_out, data_in, index, basis, threadsperblock)
Expand Down
2 changes: 1 addition & 1 deletion src/deepmr/fft/_sparse/toeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def apply_toeplitz(
).contiguous() # (nvoxels, nbatches, ncontrasts)

# actual interpolation
if toeplitz_kernel.device == "cpu":
if toeplitz_kernel.device == "cpu" or toeplitz_kernel.device == torch.device("cpu"):
do_selfadjoint_interpolation(data_out, data_in, toeplitz_kernel.value)
else:
do_selfadjoint_interpolation_cuda(
Expand Down
8 changes: 3 additions & 5 deletions src/deepmr/fft/sparse_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def prepare_sampling(indexes, shape, device="cpu"):
interpolator : dict
Structure containing sparse interpolator matrix:
* index (``torch.Tensor[int]``): indexes of the non-zero entries of interpolator sparse matrix of shape (ndim, ncoord, width).
* value (``torch.Tensor[float32]``): values of the non-zero entries of interpolator sparse matrix of shape (ndim, ncoord, width).
* index (``torch.Tensor[int]``): indexes of the non-zero entries of interpolator sparse matrix of shape (ndim, ncoord).
* dshape (``Iterable[int]``): oversample grid shape of shape (ndim,). Order of axes is (z, y, x).
* ishape (``Iterable[int]``): interpolator shape (ncontrasts, nview, nsamples)
* ndim (``int``): number of spatial dimensions.
Expand Down Expand Up @@ -76,9 +75,8 @@ def plan_toeplitz_fft(coord, shape, basis=None, device="cpu"):
Parameters
----------
coord : torch.Tensor
K-space coordinates of shape ``(ncontrasts, nviews, nsamples, ndims)``.
Coordinates must be normalized between ``(-0.5 * shape[i], 0.5 * shape[i])``,
with ``i = (z, y, x)``.
Sampled k-space locations of shape ``(ncontrasts, nviews, nsamples, ndims)``.
Indexes must be between ``(0, shape[i])``, with ``i = (z, y, x)``.
shape : int | Iterable[int]
Oversampled grid size of shape ``(ndim,)``.
If scalar, isotropic matrix is assumed.
Expand Down
2 changes: 1 addition & 1 deletion src/deepmr/recon/calib/acs.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _find_cart_acs(kspace):
kspace = torch.as_tensor(kspace, device=device, dtype=torch.complex64)

kspace = kspace.swapaxes(0, 1) # (ncoils, nz, ny, nx)
kspace = _fft.fft(kspace, axes=(1))
kspace = _fft.fft(kspace, axes=(1,))

return kspace

Expand Down
33 changes: 26 additions & 7 deletions src/deepmr/recon/calib/grog/_grog_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,45 @@

import deepmr

import torch
import numpy as np

import matplotlib.pyplot as plt

from deepmr.recon.calib.grog import grogop

# define object, trajectory and coils
img0 = deepmr.shepp_logan(128)
smap0 = deepmr.sensmap((8, 128, 128))
head = deepmr.radial((128), nviews=200, osf=2.0)
img0 = deepmr.shepp_logan(256)
smap0 = deepmr.sensmap((32, 256, 256))
head = deepmr.radial((256), nviews=200, osf=2.0)

# nufft recon
ksp = deepmr.fft.nufft(smap0[:, None, ...] * img0, head.traj)
img = deepmr.fft.nufft_adj(head.dcf * ksp, head.traj, head.shape)
ksp = deepmr.fft.nufft(smap0[:, None, ...] * img0, head.traj, oversamp=2.0)
img = deepmr.fft.nufft_adj(head.dcf * ksp, head.traj, head.shape, oversamp=2.0)
img = deepmr.rss(img, axis=0).squeeze()
img = abs(img)

# get sense map and calibration data
smap, cal_data = deepmr.recon.espirit_cal(ksp, head.traj, head.dcf, head.shape)

# get cartesian ksp and indexes
d, indexes, weights = grogop.grog_interp(
ksp, cal_data, head.traj, head.shape, lamda=0.05
ksp, cal_data, head.traj, head.shape, lamda=0.0,
)

# recon
img_grog = deepmr.fft.sparse_ifft(weights * d, indexes, head.shape)
img_grog = deepmr.rss(img_grog, axis=0).squeeze()
img_grog = abs(img_grog)

# normalize
out0 = abs(img)
out = abs(img_grog)

out0 = out0 / np.nanmax(out0)
out = out / np.nanmax(out)

plt.subplot(1,2,1)
plt.imshow(abs(np.concatenate((out0, out), axis=-1)), cmap="gray")
plt.subplot(1,2,2)
plt.imshow(abs(out0 - out), cmap="bwr"), plt.colorbar()

118 changes: 73 additions & 45 deletions src/deepmr/recon/calib/grog/grogop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

import torch

import scipy

from ...._utils import backend


def grog_interp(
input, calib, coord, shape, lamda=0.01, nsteps=11, device=None, threadsperblock=128
input, calib, coord, shape, lamda=0.01, nsteps=11, device=None, threadsperblock=128,
):
"""
GRAPPA Operator Gridding (GROP) interpolation of Non-Cartesian datasets.
Expand Down Expand Up @@ -92,8 +91,8 @@ def grog_interp(
ndim = coord.shape[-1]

# get grappa operator
kern = _calc_grappaop(calib, ndim, lamda)

kern = _calc_grappaop(calib, ndim, lamda, device)
# get coord shape
cshape = coord.shape

Expand All @@ -110,40 +109,41 @@ def grog_interp(
) # (nslices, nsamples, ncoils) or (nsamples, ncoils)

# perform product
deltas = (np.arange(nsteps) - (nsteps - 1) // 2) / (nsteps - 1)
deltas = (torch.arange(nsteps) - (nsteps - 1) // 2) / (nsteps - 1)

# get Gx, Gy, Gz
Gx = _weight_grid(kern.Gx, deltas) # (nslices, nsteps, nc, nc)
Gy = _weight_grid(kern.Gy, deltas) # (nslices, nsteps, nc, nc)
Gx = _weight_grid(kern.Gx, deltas) # 2D: (nsteps, nslices, nc, nc); 3D: (nsteps, nc, nc)
Gy = _weight_grid(kern.Gy, deltas) # 2D: (nsteps, nslices, nc, nc); 3D: (nsteps, nc, nc)

if ndim == 3:
Gz = _weight_grid(kern.Gz, deltas) # (nslices, nsteps, nc, nc)
Gz = _weight_grid(kern.Gz, deltas) # (nsteps, nc, nc), 3D only
else:
Gz = None

# build G
if ndim == 2:
Gx = Gx[None, ...]
Gy = Gy[:, None, ...]
Gx = np.repeat(Gx, nsteps, axis=0)
Gy = np.repeat(Gy, nsteps, axis=1)
Gx = Gx.reshape(-1, *Gx.shape[-3:])
Gy = Gy.reshape(-1, *Gy.shape[-3:])
G = Gx @ Gy
Gx = np.repeat(Gx, nsteps, axis=0) # (nsteps, nsteps, nslices, nc, nc)
Gy = np.repeat(Gy, nsteps, axis=1) # (nsteps, nsteps, nslices, nc, nc)
Gx = Gx.reshape(-1, *Gx.shape[-3:]) # (nsteps**2, nslices, nc, nc)
Gy = Gy.reshape(-1, *Gy.shape[-3:]) # (nsteps**2, nslices, nc, nc)
G = Gx @ Gy # (nsteps**2, nslices, nc, nc)
elif ndim == 3:
Gx = Gx[None, None, ...]
Gy = Gy[None, :, None, ...]
Gz = Gz[:, None, None, ...]
Gx = np.repeat(Gx, nsteps, axis=0)
Gx = np.repeat(Gx, nsteps, axis=1)
Gy = np.repeat(Gy, nsteps, axis=0)
Gy = np.repeat(Gy, nsteps, axis=2)
Gz = np.repeat(Gz, nsteps, axis=1)
Gz = np.repeat(Gz, nsteps, axis=2)
Gx = Gx.reshape(-1, *Gx.shape[-2:])
Gy = Gy.reshape(-1, *Gy.shape[-2:])
Gz = Gz.reshape(-1, *Gz.shape[-2:])
G = Gx @ Gy @ Gz

Gx = np.repeat(Gx, nsteps, axis=0) # (nsteps, nsteps, nsteps, nc, nc)
Gx = np.repeat(Gx, nsteps, axis=1) # (nsteps, nsteps, nsteps, nc, nc)
Gy = np.repeat(Gy, nsteps, axis=0) # (nsteps, nsteps, nsteps, nc, nc)
Gy = np.repeat(Gy, nsteps, axis=2) # (nsteps, nsteps, nsteps, nc, nc)
Gz = np.repeat(Gz, nsteps, axis=1) # (nsteps, nsteps, nsteps, nc, nc)
Gz = np.repeat(Gz, nsteps, axis=2) # (nsteps, nsteps, nsteps, nc, nc)
Gx = Gx.reshape(-1, *Gx.shape[-2:]) # (nsteps**3, nc, nc)
Gy = Gy.reshape(-1, *Gy.shape[-2:]) # (nsteps**3, nc, nc)
Gz = Gz.reshape(-1, *Gz.shape[-2:]) # (nsteps**3, nc, nc)
G = Gx @ Gy @ Gz # (nsteps**3, nc, nc)
# build indexes
indexes = torch.round(coord)
lut = indexes - coord
Expand All @@ -156,7 +156,7 @@ def grog_interp(
input = input.swapaxes(0, 1) # (nsamples, nslices, ncoils)

# perform interpolation
if device == "cpu":
if device == "cpu" or device == torch.device("cpu"):
output = do_interpolation(input, G, lut)
else:
output = do_interpolation_cuda(input, G, lut, threadsperblock)
Expand All @@ -179,26 +179,34 @@ def grog_interp(
weights = counts[idx]

# count
# max_value = torch.max(weights)
# counts = torch.bincount(weights, minlength=max_value+1)
# weights = counts[weights]

weights = weights.reshape(*indexes.shape[:-1])
weights = weights.to(torch.float32)
weights = 1 / weights

# finalize data
# finalize
if ndim == 2:
output = output.swapaxes(0, 1) # (nslices, nsamples, ncoils)
output = output.reshape(*dshape)
output = output.swapaxes(-3, -1)
output = output[..., 0]
output = output.reshape(ishape)


# remove out-of-boundaries
shape = list(shape[-ndim:])[::-1] # (x, y, z)
for n in range(ndim):
outside = indexes[..., n] < 0
output[..., outside] = 0.0
indexes[..., n][outside] = 0
outside = indexes[..., n] >= shape[n]
indexes[..., n][outside] = shape[n]-1
output[..., outside] = 0.0

# cast back to original device
output = output.to(idevice)
indexes = indexes.to(idevice)
weights = weights.to(idevice)


# if required, cast back to numpy
if isnumpy:
output = output.numpy(force=True)
indexes = indexes.to(idevice)
Expand All @@ -208,9 +216,9 @@ def grog_interp(


# %% subroutines
def _calc_grappaop(calib, ndim, lamda):
def _calc_grappaop(calib, ndim, lamda, device):
# as Tensor
calib = torch.as_tensor(calib)
calib = torch.as_tensor(calib, device=device)

# expand
if len(calib.shape) == 3: # single slice (nc, ny, nx)
Expand All @@ -223,11 +231,11 @@ def _calc_grappaop(calib, ndim, lamda):
gz, gy, gx = _grappa_op_3d(calib, lamda)

# prepare output
GrappaOp = SimpleNamespace()
GrappaOp.Gx, GrappaOp.Gy = (gx.numpy(force=True), gy.numpy(force=True))
GrappaOp = SimpleNamespace(Gx=gx, Gy=gy)
# GrappaOp.Gx, GrappaOp.Gy = (gx.numpy(force=True), gy.numpy(force=True))

if ndim == 3:
GrappaOp.Gz = gz.numpy(force=True)
GrappaOp.Gz = gz #.numpy(force=True)
else:
GrappaOp.Gz = None

Expand Down Expand Up @@ -305,25 +313,43 @@ def _bdot(a, b):


def _weight_grid(A, weight):
return np.stack([_matrix_power(A, w) for w in weight], axis=0)


def _matrix_power(A, t):
if len(A.shape) == 2:
return scipy.linalg.fractional_matrix_power(A, t)
else:
return np.stack([scipy.linalg.fractional_matrix_power(a, t) for a in A])
# decompose
L, V = torch.linalg.eig(A)

# raise to power along expanded first dim
if len(L.shape) == 2: # 3D case, (nc, nc)
L = L[None, ...]**weight[:, None, None]
else: # 2D case, (nslices, nc, nc)
L = L[None, ...]**weight[:, None, None, None]

# unsqueeze batch dimension for V
V = V[None, ...]

# put together and return
return V @ torch.diag_embed(L) @ torch.linalg.inv(V)

# def _weight_grid(A, weight):
# return np.stack([_matrix_power(A, w) for w in weight], axis=0)


# def _matrix_power(A, t):
# if len(A.shape) == 2:
# return scipy.linalg.fractional_matrix_power(A, t)
# else:
# return np.stack([scipy.linalg.fractional_matrix_power(a, t) for a in A])


def do_interpolation(noncart, G, lut):
cart = torch.zeros(noncart.shape, dtype=noncart.dtype, device=noncart.device)
cart = backend.pytorch2numba(cart)
G = backend.pytorch2numba(G)
noncart = backend.pytorch2numba(noncart)
lut = backend.pytorch2numba(lut)

_interp(cart, noncart, G, lut)

noncart = backend.numba2pytorch(noncart)
G = backend.numba2pytorch(G)
cart = backend.numba2pytorch(cart)
lut = backend.numba2pytorch(lut)

Expand Down Expand Up @@ -371,13 +397,15 @@ def do_interpolation_cuda(noncart, G, lut, threadsperblock):

cart = torch.zeros(noncart.shape, dtype=noncart.dtype, device=noncart.device)
cart = backend.pytorch2numba(cart)
G = backend.pytorch2numba(G)
noncart = backend.pytorch2numba(noncart)
lut = backend.pytorch2numba(lut)

# run kernel
_interp_cuda[blockspergrid, threadsperblock](cart, noncart, G, lut)

noncart = backend.numba2pytorch(noncart)
G = backend.numba2pytorch(G)
cart = backend.numba2pytorch(cart)
lut = backend.numba2pytorch(lut)

Expand Down
2 changes: 1 addition & 1 deletion src/deepmr/recon/calib/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def intensity_scaling(input, ndim):
Input signal of shape ``(..., ny, nx)`` (2D) or
``(..., nz, ny, nx)`` (3D).
ndim : int, optional
PNumber of spatial dimensions.
Number of spatial dimensions.
Returns
-------
Expand Down

0 comments on commit 8f9863b

Please sign in to comment.