
Commit e15faa1

Add memory efficient reshape of broadcasted tensors (#557)
1 parent 31756e3 commit e15faa1

3 files changed: +136 -2 lines changed

src/mrpro/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,14 +6,15 @@
 from mrpro.utils.remove_repeat import remove_repeat
 from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
 from mrpro.utils.split_idx import split_idx
-from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view
+from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted
 import mrpro.utils.unit_conversion
 
 __all__ = [
     "broadcast_right",
     "fill_range_",
     "reduce_view",
     "remove_repeat",
+    "reshape_broadcasted",
     "slice_profiles",
     "smap",
     "split_idx",

src/mrpro/utils/reshape.py

Lines changed: 101 additions & 0 deletions
@@ -1,6 +1,8 @@
 """Tensor reshaping utilities."""
 
 from collections.abc import Sequence
+from functools import lru_cache
+from math import prod
 
 import torch
 
@@ -99,3 +101,102 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torch.Tensor:
         for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True))
     ]
     return torch.as_strided(x, newsize, stride)
+
+
+@lru_cache
+def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]:
+    """Get reshape reduce index (cached helper function for reshape_broadcasted).
+
+    This function tries to group axes from new_shape and old_shape into the smallest groups that have
+    the same number of elements, starting from the right.
+    If all axes of the old shape in a group have stride 0, the group can be reduced.
+
+    Example:
+        old_shape = (30, 2, 2, 3)
+        new_shape = (6, 5, 4, 3)
+    will result in the groups (starting from the right):
+        - old: 3       new: 3
+        - old: 2, 2    new: 4
+        - old: 30      new: 6, 5
+    Only the "old" groups are important.
+    If all axes that are grouped together in an "old" group have stride 0 (i.e. are broadcasted),
+    we can collapse them to singleton dimensions.
+    This function returns the indexer that either collapses dimensions to singleton or keeps all
+    elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
+    """
+    idx = []
+    pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1  # start from the right
+    while pointer_old >= 0:
+        product_new, product_old = 1, 1  # the number of elements in the current "new" and "old" group
+        group: list[int] = []
+        while product_old != product_new or not group:
+            if product_old <= product_new:
+                # grow the "old" group
+                product_old *= old_shape[pointer_old]
+                group.append(pointer_old)
+                pointer_old -= 1
+            else:
+                # grow the "new" group
+                # we don't need to remember the "new" axes, only the number of elements covered
+                product_new *= new_shape[pointer_new]
+                pointer_new -= 1
+        # we found a group. now we need to decide what to do.
+        if all(old_stride[d] == 0 for d in group):
+            # all dimensions are broadcasted
+            # -> reduce to singleton
+            idx.extend([slice(1)] * len(group))
+        else:
+            # preserve dimensions
+            idx.extend([slice(None)] * len(group))
+    idx = idx[::-1]  # we worked right to left, but our index should be left to right
+    return idx
+
+
+def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
+    """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible.
+
+    Parameters
+    ----------
+    tensor
+        The input tensor to reshape.
+    shape
+        The target shape for the tensor. One of the values can be `-1` and its size will be inferred.
+
+    Returns
+    -------
+        A tensor reshaped to the target shape, preserving broadcasted dimensions where feasible.
+
+    """
+    try:
+        # if we can view the tensor directly, it will preserve broadcasting
+        return tensor.view(shape)
+    except RuntimeError:
+        # we cannot do a view, we need to do more work:
+
+        # -1 means infer size, i.e. the remaining elements of the input not already covered by the other axes
+        negative_ones = shape.count(-1)
+        size = tensor.shape.numel()
+        if not negative_ones:
+            if prod(shape) != size:
+                # use the same exception as pytorch
+                raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+        elif negative_ones > 1:
+            raise RuntimeError('only one dimension can be inferred') from None
+        elif negative_ones == 1:
+            # we need to figure out the size of the "-1" dimension
+            known_size = -prod(shape)  # negative, as it includes the -1
+            if size % known_size:
+                # non-integer result: no possible size of the -1 axis exists
+                raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+            shape = tuple(size // known_size if s == -1 else s for s in shape)
+
+        # most of the broadcasted dimensions can be preserved: only dimensions that are joined with
+        # non-broadcasted dimensions cannot be preserved and must be made contiguous.
+        # all dimensions that can be preserved as broadcasted are first collapsed to singleton,
+        # such that contiguous does not create copies along these axes.
+        idx = _reshape_idx(tensor.shape, shape, tensor.stride())
+        # make contiguous only in dimensions in which broadcasting cannot be preserved
+        semicontiguous = tensor[idx].contiguous()
+        # finally, we can expand the broadcasted dimensions to the requested shape
+        semicontiguous = semicontiguous.expand(tensor.shape)
+        return semicontiguous.view(shape)
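
To see what the new utility does in practice, here is a minimal usage sketch (not part of the commit; it mirrors one of the parametrized test cases below and assumes mrpro with this change and PyTorch are installed):

import torch
from mrpro.utils import reshape_broadcasted

# only 2 values are actually stored; expand() broadcasts them to 100 * 2 * 2 = 400 logical elements
x = torch.arange(2.0).reshape(1, 2, 1).expand(100, 2, 2)  # strides (0, 1, 0)

y = reshape_broadcasted(x, 100, 4)
print(y.shape)     # torch.Size([100, 4])
print(y.stride())  # (0, 1): the leading broadcast dimension keeps stride 0

# a plain reshape() would materialize a contiguous copy of all 400 elements
z = x.reshape(100, 4)
print(torch.equal(y, z))  # True: same values, but y keeps a much smaller backing storage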

tests/utils/test_reshape.py

Lines changed: 33 additions & 1 deletion
@@ -1,7 +1,8 @@
 """Tests for reshaping utilities."""
 
+import pytest
 import torch
-from mrpro.utils import broadcast_right, reduce_view, unsqueeze_left, unsqueeze_right
+from mrpro.utils import broadcast_right, reduce_view, reshape_broadcasted, unsqueeze_left, unsqueeze_right
 
 from tests import RandomGenerator
 
@@ -51,3 +52,34 @@ def test_reduce_view():
     reduced_one_pos = reduce_view(tensor, 0)
     assert reduced_one_pos.shape == (1, 2, 3, 4, 5, 6)
     assert torch.equal(reduced_one_pos.expand_as(tensor), tensor)
+
+
+@pytest.mark.parametrize(
+    ('shape', 'expand_shape', 'permute', 'final_shape', 'expected_stride'),
+    [
+        ((1, 2, 3, 1, 1), (1, 2, 3, 4, 5), (0, 2, 1, 3, 4), (1, 6, 2, 2, 5), (6, 1, 0, 0, 0)),
+        ((1, 2, 1), (100, 2, 2), (0, 1, 2), (100, 4), (0, 1)),
+        ((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 0, 1), (1, 2, 6, 10, 1), (0, 0, 0, 0, 0)),
+        ((1, 2, 3), (1, -1, 3), (0, 1, 2), (6,), (1,)),
+    ],
+)
+def test_reshape_broadcasted(shape, expand_shape, permute, final_shape, expected_stride):
+    """Test reshape_broadcasted"""
+    rng = RandomGenerator(0)
+    tensor = rng.float32_tensor(shape).expand(*expand_shape).permute(*permute)
+    reshaped = reshape_broadcasted(tensor, *final_shape)
+    expected_values = tensor.reshape(*final_shape)
+    assert reshaped.shape == expected_values.shape
+    assert reshaped.stride() == expected_stride
+    assert torch.equal(reshaped, expected_values)
+
+
+def test_reshape_broadcasted_fail():
+    """Test reshape_broadcasted with invalid input"""
+    a = torch.ones(2)
+    with pytest.raises(RuntimeError, match='invalid'):
+        reshape_broadcasted(a, 3)
+    with pytest.raises(RuntimeError, match='invalid'):
+        reshape_broadcasted(a, -1, -3)
+    with pytest.raises(RuntimeError, match='only one dimension'):
+        reshape_broadcasted(a, -1, -1)
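
As a rough illustration of the memory saving the commit title refers to, here is another sketch (not part of the commit; Tensor.untyped_storage() assumes PyTorch >= 2.0):

import torch
from mrpro.utils import reshape_broadcasted

# 8 stored float32 values, broadcast to 1000 * 8 * 8 = 64000 logical elements
x = torch.zeros(1, 8, 1).expand(1000, 8, 8)

efficient = reshape_broadcasted(x, 1000, 64)
naive = x.reshape(1000, 64)  # falls back to a full contiguous copy

print(efficient.untyped_storage().nbytes())  # 256 bytes: only the 64 non-broadcast elements are copied
print(naive.untyped_storage().nbytes())      # 256000 bytes: all 64000 elements are materialized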
