Skip to content

Commit 92e3193

Browse files
fzimmermann89pre-commit-ci[bot]schuenke
authored
Move trajectory scaling into KTrajectory (#582)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Schuenke <patrick.schuenke@gmail.com>
1 parent 4096ead commit 92e3193

File tree

5 files changed

+101
-33
lines changed

5 files changed

+101
-33
lines changed

src/mrpro/data/KTrajectory.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""KTrajectory dataclass."""
22

33
from dataclasses import dataclass
4+
from typing import Literal
45

56
import numpy as np
67
import torch
78
from typing_extensions import Self
89

910
from mrpro.data.enums import TrajType
1011
from mrpro.data.MoveDataMixin import MoveDataMixin
12+
from mrpro.data.SpatialDimension import SpatialDimension
1113
from mrpro.utils import remove_repeat
1214
from mrpro.utils.summarize_tensorvalues import summarize_tensorvalues
1315

@@ -69,29 +71,52 @@ def from_tensor(
6971
cls,
7072
tensor: torch.Tensor,
7173
stack_dim: int = 0,
72-
repeat_detection_tolerance: float | None = 1e-8,
74+
axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx',
75+
repeat_detection_tolerance: float | None = 1e-6,
7376
grid_detection_tolerance: float = 1e-3,
77+
scaling_matrix: SpatialDimension | None = None,
7478
) -> Self:
7579
"""Create a KTrajectory from a tensor representation of the trajectory.
7680
77-
Reduces repeated dimensions to singletons if repeat_detection_tolerance
78-
is not set to None.
79-
81+
Reduces repeated dimensions to singletons if repeat_detection_tolerance is not set to None.
8082
8183
Parameters
8284
----------
8385
tensor
8486
The tensor representation of the trajectory.
85-
This should be a 5-dim tensor, with (kz,ky,kx) stacked in this order along stack_dim
87+
This should be a 5-dim tensor, with (kz, ky, kx) stacked in this order along `stack_dim`.
8688
stack_dim
87-
The dimension in the tensor the directions have been stacked along.
89+
The dimension in the tensor along which the directions are stacked.
90+
axes_order
91+
The order of the axes in the tensor. The MRpro convention is 'zyx'.
8892
repeat_detection_tolerance
89-
detects if broadcasting can be used, i.e. if dimensions are repeated.
90-
Set to None to disable.
93+
Tolerance for detecting repeated dimensions (broadcasting).
94+
If trajectory points differ by less than this value, they are considered identical.
95+
Set to None to disable this feature.
9196
grid_detection_tolerance
92-
tolerance to detect if trajectory points are on integer grid positions
97+
Tolerance for detecting whether trajectory points align with integer grid positions.
98+
This tolerance is applied after rescaling if `scaling_matrix` is provided.
99+
scaling_matrix
100+
If a scaling matrix is provided, the trajectory is rescaled to fit within
101+
the dimensions of the matrix. If not provided, the trajectory remains unchanged.
102+
93103
"""
94-
kz, ky, kx = torch.unbind(tensor, dim=stack_dim)
104+
ks = tensor.unbind(dim=stack_dim)
105+
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')
106+
107+
def rescale(k: torch.Tensor, size: float) -> torch.Tensor:
108+
max_abs_range = 2 * k.abs().max()
109+
if size < 2 or max_abs_range < 1e-6:
110+
# a single encoding point should be at zero
111+
# avoid division by zero
112+
return torch.zeros_like(k)
113+
return k * (size / max_abs_range)
114+
115+
if scaling_matrix is not None:
116+
kz = rescale(kz, scaling_matrix.z)
117+
ky = rescale(ky, scaling_matrix.y)
118+
kx = rescale(kx, scaling_matrix.x)
119+
95120
return cls(
96121
kz,
97122
ky,

src/mrpro/data/KTrajectoryRawShape.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
"""KTrajectoryRawShape dataclass."""
22

33
from dataclasses import dataclass
4+
from typing import Literal
45

56
import numpy as np
67
import torch
78
from einops import rearrange
9+
from typing_extensions import Self
810

911
from mrpro.data.KTrajectory import KTrajectory
1012
from mrpro.data.MoveDataMixin import MoveDataMixin
13+
from mrpro.data.SpatialDimension import SpatialDimension
1114

1215

1316
@dataclass(slots=True, frozen=True)
@@ -32,6 +35,52 @@ class KTrajectoryRawShape(MoveDataMixin):
3235
repeat_detection_tolerance: None | float = 1e-3
3336
"""tolerance for repeat detection. Set to None to disable."""
3437

38+
@classmethod
39+
def from_tensor(
40+
cls,
41+
tensor: torch.Tensor,
42+
stack_dim: int = 0,
43+
axes_order: Literal['zxy', 'zyx', 'yxz', 'yzx', 'xyz', 'xzy'] = 'zyx',
44+
repeat_detection_tolerance: float | None = 1e-6,
45+
scaling_matrix: SpatialDimension | None = None,
46+
) -> Self:
47+
"""Create a KTrajectoryRawShape from a tensor representation of the trajectory.
48+
49+
Parameters
50+
----------
51+
tensor
52+
The tensor representation of the trajectory.
53+
This should be a 5-dim tensor, with (kz, ky, kx) stacked in this order along `stack_dim`.
54+
stack_dim
55+
The dimension in the tensor along which the directions are stacked.
56+
axes_order
57+
The order of the axes in the tensor. The MRpro convention is 'zyx'.
58+
repeat_detection_tolerance
59+
Tolerance for detecting repeated dimensions (broadcasting).
60+
If trajectory points differ by less than this value, they are considered identical.
61+
Set to None to disable this feature.
62+
scaling_matrix
63+
If a scaling matrix is provided, the trajectory is rescaled to fit within
64+
the dimensions of the matrix. If not provided, the trajectory remains unchanged.
65+
"""
66+
ks = tensor.unbind(dim=stack_dim)
67+
kz, ky, kx = (ks[axes_order.index(axis)] for axis in 'zyx')
68+
69+
def rescale(k: torch.Tensor, size: float) -> torch.Tensor:
70+
max_abs_range = 2 * k.abs().max()
71+
if size < 2 or max_abs_range < 1e-6:
72+
# a single encoding point should be at zero
73+
# avoid division by zero
74+
return torch.zeros_like(k)
75+
return k * (size / max_abs_range)
76+
77+
if scaling_matrix is not None:
78+
kz = rescale(kz, scaling_matrix.z)
79+
ky = rescale(ky, scaling_matrix.y)
80+
kx = rescale(kx, scaling_matrix.x)
81+
82+
return cls(kz, ky, kx, repeat_detection_tolerance=repeat_detection_tolerance)
83+
3584
def sort_and_reshape(
3685
self,
3786
sort_idx: np.ndarray,

src/mrpro/data/traj_calculators/KTrajectoryPulseq.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pathlib import Path
44

5-
import pypulseq as pp
65
import torch
76
from einops import rearrange
87

@@ -40,8 +39,10 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
4039
-------
4140
trajectory of type KTrajectoryRawShape
4241
"""
42+
from pypulseq import Sequence
43+
4344
# create PyPulseq Sequence object and read .seq file
44-
seq = pp.Sequence()
45+
seq = Sequence()
4546
seq.read(file_path=str(self.seq_path))
4647

4748
# calculate k-space trajectory using PyPulseq
@@ -52,20 +53,11 @@ def __call__(self, kheader: KHeader) -> KTrajectoryRawShape:
5253
n_samples = torch.unique(n_samples)
5354
if len(n_samples) > 1:
5455
raise ValueError('We currently only support constant number of samples')
55-
n_k0 = int(n_samples.item())
56-
57-
def rescale_and_reshape_traj(k_traj: torch.Tensor, encoding_size: int):
58-
if encoding_size > 1 and torch.max(torch.abs(k_traj)) > 0:
59-
k_traj = k_traj * encoding_size / (2 * torch.max(torch.abs(k_traj)))
60-
else:
61-
# We force k_traj to be 0 if encoding_size = 1. This is typically the case for kz in 2D sequences.
62-
# However, it happens that seq.calculate_kspace() returns values != 0 (numerical noise) in such cases.
63-
k_traj = torch.zeros_like(k_traj)
64-
return rearrange(k_traj, '(other k0) -> other k0', k0=n_k0)
65-
66-
# rearrange k-space trajectory to match MRpro convention
67-
kx = rescale_and_reshape_traj(k_traj_adc[0], kheader.encoding_matrix.x)
68-
ky = rescale_and_reshape_traj(k_traj_adc[1], kheader.encoding_matrix.y)
69-
kz = rescale_and_reshape_traj(k_traj_adc[2], kheader.encoding_matrix.z)
7056

71-
return KTrajectoryRawShape(kz, ky, kx, self.repeat_detection_tolerance)
57+
k_traj_reshaped = rearrange(k_traj_adc, 'xyz (other k0) -> xyz other k0', k0=int(n_samples.item()))
58+
return KTrajectoryRawShape.from_tensor(
59+
k_traj_reshaped,
60+
axes_order='xyz',
61+
scaling_matrix=kheader.encoding_matrix,
62+
repeat_detection_tolerance=self.repeat_detection_tolerance,
63+
)

tests/data/_PulseqRadialTestSeq.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def __init__(self, seq_filename: str, n_x=256, n_spokes=10):
2929

3030
system = pypulseq.Opts()
3131
rf, gz, _ = pypulseq.make_sinc_pulse(flip_angle=0.1, slice_thickness=1e-3, system=system, return_gz=True)
32-
gx = pypulseq.make_trapezoid(channel='x', flat_area=n_x * delta_k, flat_time=2e-3, system=system)
32+
gx = pypulseq.make_trapezoid(
33+
channel='x', flat_area=n_x * delta_k, flat_time=n_x * system.grad_raster_time, system=system
34+
)
3335
adc = pypulseq.make_adc(num_samples=n_x, duration=gx.flat_time, delay=gx.rise_time, system=system)
3436
gx_pre = pypulseq.make_trapezoid(channel='x', area=-gx.area / 2 - delta_k / 2, duration=2e-3, system=system)
3537
gz_reph = pypulseq.make_trapezoid(channel='z', area=-gz.area / 2, duration=2e-3, system=system)

tests/data/test_traj_calculators.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,11 @@ def test_KTrajectoryPulseq_validseq_random_header(pulseq_example_rad_seq, valid_
260260
trajectory_calculator = KTrajectoryPulseq(seq_path=pulseq_example_rad_seq.seq_filename)
261261
trajectory = trajectory_calculator(kheader=valid_rad2d_kheader)
262262

263-
kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze(0).squeeze(0)
264-
kx_test *= valid_rad2d_kheader.encoding_matrix.x / (2 * torch.max(torch.abs(kx_test)))
263+
kx_test = pulseq_example_rad_seq.traj_analytical.kx.squeeze()
264+
kx_test = kx_test * valid_rad2d_kheader.encoding_matrix.x / (2 * kx_test.abs().max())
265265

266-
ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze(0).squeeze(0)
267-
ky_test *= valid_rad2d_kheader.encoding_matrix.y / (2 * torch.max(torch.abs(ky_test)))
266+
ky_test = pulseq_example_rad_seq.traj_analytical.ky.squeeze()
267+
ky_test = ky_test * valid_rad2d_kheader.encoding_matrix.y / (2 * ky_test.abs().max())
268268

269269
torch.testing.assert_close(trajectory.kx.to(torch.float32), kx_test.to(torch.float32), atol=1e-2, rtol=1e-3)
270270
torch.testing.assert_close(trajectory.ky.to(torch.float32), ky_test.to(torch.float32), atol=1e-2, rtol=1e-3)

0 commit comments

Comments
 (0)