Skip to content

Commit

Permalink
fix: various bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mcencini committed Apr 12, 2024
1 parent 64055fa commit bb9fab9
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 23 deletions.
9 changes: 5 additions & 4 deletions src/deepmr/_signal/subspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,17 @@ def svd(input, ncoeff, axis):
output = output.reshape(nrows, ncols)

# perform svd
u, s, vh = torch.linalg.svd(output, full_matrices=None)
u, s, vh = torch.linalg.svd(output, full_matrices=False)

# compress data
basis = vh[..., :ncoeff]
v = vh.conj().t()
basis = v[..., :ncoeff]
output = output @ basis

# calculate explained variance
explained_variance = s**2 / (nrows - 1) # (neigenvalues,)
explained_variance = explained_variance / explained_variance.sum()
explained_variance = torch.cumsum(explained_variance)[ncoeff - 1]
explained_variance = torch.cumsum(explained_variance, 0)

# reshape
output = output.reshape(*ishape[:-1], ncoeff)
Expand All @@ -141,4 +142,4 @@ def svd(input, ncoeff, axis):
output = output.numpy()
basis = basis.numpy()

return basis, output, 100 * explained_variance
return basis, output, 100 * explained_variance[ncoeff - 1], 100 * explained_variance
5 changes: 3 additions & 2 deletions src/deepmr/fft/_interp/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def apply_gridding(data_in, interpolator, basis=None, device=None, threadsperblo
data_out = torch.zeros(
(ncoeff, batch_size, *dshape), dtype=data_in.dtype, device=device
)
# print(basis.shape)

# do actual gridding
if device == "cpu":
Expand Down Expand Up @@ -329,11 +330,11 @@ def _gridding_lowrank2(
ncoeff, batch_size, _, _ = cart_data.shape
nframes = noncart_data.shape[0]
npts = noncart_data.shape[-1]

# unpack interpolator
yindex, xindex = interp_index
yvalue, xvalue = interp_value

# get interpolator width
ywidth = yindex.shape[-1]
xwidth = xindex.shape[-1]
Expand Down
4 changes: 3 additions & 1 deletion src/deepmr/fft/_interp/toeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def apply_toeplitz(
# reshape back
data_out = data_out.reshape(*shape)
else:
data_out = toeplitz_kernel.value * data_in
kvalue = backend.numba2pytorch(toeplitz_kernel.value)
data_out = kvalue * data_in
toeplitz_kernel.value = backend.pytorch2numba(kvalue)

# collect garbage
gc.collect()
Expand Down
10 changes: 8 additions & 2 deletions src/deepmr/io/header/matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def read_matlab_acqhead(filepath, dcfpath=None, methodpath=None, sliceprofpath=N
ndim = k.shape[-1]

# get adc
adc = _get_adc(matfile)
adc, readout_length = _get_adc(matfile)

# get dcf
dcf = _get_dcf(matfile, k, filepath, dcfpath)
Expand Down Expand Up @@ -74,6 +74,11 @@ def read_matlab_acqhead(filepath, dcfpath=None, methodpath=None, sliceprofpath=N

# get basis
head = _get_basis(head, matfile)

# get echo time
if "te" in matfile:
head.TE = np.asarray(matfile["te"], dtype=np.float32).squeeze()
head.user["readout_length"] = readout_length * len(np.unique(head.TE))

return head

Expand Down Expand Up @@ -164,8 +169,9 @@ def _get_adc(matfile):
adc = matfile["inds"].squeeze().astype(bool)
else:
raise RuntimeError("ADC indexes not found!")
readout_length = adc.shape[0]
adc = np.argwhere(adc)[[0, -1]].squeeze()
return adc
return adc, readout_length


def _get_dcf(matfile, k, filename, dcfname):
Expand Down
7 changes: 5 additions & 2 deletions src/deepmr/optim/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(self, AHA, AHy, ndim=None, tol=None):
self.ndim = AHA.ndim
except Exception:
self.ndim = ndim

# assign operators
self.AHA = AHA
self.AHy = AHy
Expand All @@ -194,6 +194,9 @@ def __init__(self, AHA, AHy, ndim=None, tol=None):
def dot(self, s1, s2):
dot = s1.conj() * s2
dot = dot.reshape(*s1.shape[: -self.ndim], -1).sum(axis=-1)
if np.isscalar(dot) is False:
for n in range(self.ndim):
dot = dot[..., None]

return dot

Expand All @@ -210,7 +213,7 @@ def forward(self, input):

def check_convergence(self):
if self.tol is not None:
if self.rsnew.sqrt() < self.tol:
if (self.rsnew.sqrt() < self.tol).all():
return True
else:
return False
Expand Down
2 changes: 2 additions & 0 deletions src/deepmr/recon/alg/classic_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def recon_lstsq(
nsets=1,
device=None,
cal_data=None,
sensmap=None,
toeplitz=True,
use_dcf=True,
):
Expand Down Expand Up @@ -140,6 +141,7 @@ def recon_lstsq(
basis,
device,
cal_data,
sensmap,
toeplitz,
)

Expand Down
27 changes: 15 additions & 12 deletions src/deepmr/recon/alg/linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def EncodingOp(
basis=None,
device=None,
cal_data=None,
sensmap=None,
toeplitz=False,
):
"""
Expand Down Expand Up @@ -83,10 +84,11 @@ def EncodingOp(
if ncoils == 1:
return F, FHF
else:
if cal_data is not None:
sensmap, _ = _calib.espirit_cal(cal_data.to(device), nsets=nsets)
else:
sensmap, _ = _calib.espirit_cal(data.to(device), nsets=nsets)
if sensmap is None:
if cal_data is not None:
sensmap, _ = _calib.espirit_cal(cal_data.to(device), nsets=nsets)
else:
sensmap, _ = _calib.espirit_cal(data.to(device), nsets=nsets)

# infer from mask shape whether we are using multicontrast or not
if len(mask.shape) == 2:
Expand Down Expand Up @@ -120,14 +122,15 @@ def EncodingOp(
if ncoils == 1:
return F, FHF
else:
if cal_data is not None:
sensmap, _ = _calib.espirit_cal(
cal_data.to(device), nsets=nsets, coord=traj, shape=shape, dcf=dcf
)
else:
sensmap, _ = _calib.espirit_cal(
data.to(device), nsets=nsets, coord=traj, shape=shape, dcf=dcf
)
if sensmap is None:
if cal_data is not None:
sensmap, _ = _calib.espirit_cal(
cal_data.to(device), nsets=nsets, coord=traj, shape=shape, dcf=dcf
)
else:
sensmap, _ = _calib.espirit_cal(
data.to(device), nsets=nsets, coord=traj, shape=shape, dcf=dcf
)

# infer from mask shape whether we are using multicontrast or not
if len(traj.shape) < 4:
Expand Down

0 comments on commit bb9fab9

Please sign in to comment.