@@ -1,6 +1,8 @@
 """Tensor reshaping utilities."""
 
 from collections.abc import Sequence
+from functools import lru_cache
+from math import prod
 
 import torch
 
@@ -99,3 +101,102 @@ def reduce_view(x: torch.Tensor, dim: int | Sequence[int] | None = None) -> torc
         for d, (oldsize, stride) in enumerate(zip(x.size(), stride, strict=True))
     ]
     return torch.as_strided(x, newsize, stride)
+
+
+@lru_cache
+def _reshape_idx(old_shape: tuple[int, ...], new_shape: tuple[int, ...], old_stride: tuple[int, ...]) -> list[slice]:
+    """Get reshape reduce index (cached helper function for reshape_broadcasted).
+
+    This function tries to group axes from new_shape and old_shape into the smallest groups that have
+    the same number of elements, starting from the right.
+    If all axes of the old shape in a group are stride-0 dimensions, we can reduce them.
+
+    Example:
+        old_shape = (30, 2, 2, 3)
+        new_shape = (6, 5, 4, 3)
+        This results in the groups (starting from the right):
+        - old: 3       new: 3
+        - old: 2, 2    new: 4
+        - old: 30      new: 6, 5
+    Only the "old" groups are important.
+    If all axes that are grouped together in an "old" group are stride 0 (=broadcasted),
+    we can collapse them to singleton dimensions.
+    This function returns the indexer that either collapses dimensions to singleton or keeps all
+    elements, i.e. the slices in the returned list are all either slice(1) or slice(None).
+    """
+    idx = []
+    pointer_old, pointer_new = len(old_shape) - 1, len(new_shape) - 1  # start from the right
+    while pointer_old >= 0:
+        product_new, product_old = 1, 1  # the number of elements in the current "new" and "old" group
+        group: list[int] = []
+        while product_old != product_new or not group:
+            if product_old <= product_new:
+                # increase the "old" group
+                product_old *= old_shape[pointer_old]
+                group.append(pointer_old)
+                pointer_old -= 1
+            else:
+                # increase the "new" group
+                # we don't need to track its axes, only the number of elements covered
+                product_new *= new_shape[pointer_new]
+                pointer_new -= 1
+        # we found a group. now we need to decide what to do.
+        if all(old_stride[d] == 0 for d in group):
+            # all dimensions are broadcasted
+            # -> reduce to singleton
+            idx.extend([slice(1)] * len(group))
+        else:
+            # preserve dimension
+            idx.extend([slice(None)] * len(group))
+    idx = idx[::-1]  # we worked right to left, but our index should be left to right
+    return idx
+
+
+def reshape_broadcasted(tensor: torch.Tensor, *shape: int) -> torch.Tensor:
+    """Reshape a tensor while preserving broadcasted (stride 0) dimensions where possible.
+
+    Parameters
+    ----------
+    tensor
+        The input tensor to reshape.
+    shape
+        The target shape for the tensor. One of the values can be `-1` and its size will be inferred.
+
+    Returns
+    -------
+        A tensor reshaped to the target shape, preserving broadcasted dimensions where feasible.
+
+    """
+    try:
+        # if we can view the tensor directly, it will preserve broadcasting
+        return tensor.view(shape)
+    except RuntimeError:
+        # we cannot do a view; we need to do more work:
+
+        # -1 means infer size, i.e. the remaining elements of the input not already covered by the other axes.
+        negative_ones = shape.count(-1)
+        size = tensor.shape.numel()
+        if not negative_ones:
+            if prod(shape) != size:
+                # use same exception as pytorch
+                raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+        elif negative_ones > 1:
+            raise RuntimeError('only one dimension can be inferred') from None
+        elif negative_ones == 1:
+            # we need to figure out the size of the "-1" dimension
+            known_size = -prod(shape)  # prod(shape) is negative, as it includes the -1
+            if size % known_size:
+                # non-integer result. no possible size of the -1 axis exists.
+                raise RuntimeError(f"shape '{list(shape)}' is invalid for input of size {size}") from None
+            shape = tuple(size // known_size if s == -1 else s for s in shape)
+
+        # most of the broadcasted dimensions can be preserved: only dimensions that are joined with
+        # non-broadcasted dimensions cannot be preserved and must be made contiguous.
+        # all dimensions that can be preserved as broadcasted are first collapsed to singleton,
+        # such that contiguous does not create copies along these axes.
+        idx = _reshape_idx(tensor.shape, shape, tensor.stride())
+        # make contiguous only in dimensions in which broadcasting cannot be preserved
+        semicontiguous = tensor[idx].contiguous()
+        # finally, we can expand the broadcasted dimensions to the requested shape
+        semicontiguous = semicontiguous.expand(tensor.shape)
+        return semicontiguous.view(shape)
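
A small usage sketch of the new function follows. It is not part of the commit above; the tensor shape, strides, and target shape are illustrative assumptions chosen so that a plain view fails but part of the broadcasting can still be preserved, and it assumes reshape_broadcasted and _reshape_idx are imported from this module.

# usage sketch (illustrative, not part of the commit)
import torch

# dims 0 and 2 are broadcast (stride 0); a plain .view(5, 24) fails because the stride-0
# dim 2 would have to be merged with the stride-3 and stride-1 dims into a single axis
x = torch.randn(1, 2, 1, 3).expand(5, 2, 4, 3)
print(x.stride())  # (0, 3, 0, 1)

# the cached helper collapses only the fully broadcast "old" group (here: dim 0) to a singleton
print(_reshape_idx(x.shape, (5, 24), x.stride()))
# -> [slice(None, 1, None), slice(None, None, None), slice(None, None, None), slice(None, None, None)]

y = reshape_broadcasted(x, 5, 24)
print(y.shape, y.stride())  # torch.Size([5, 24]) (0, 1): broadcasting along the first axis is kept
# only 24 elements are materialized; a plain x.reshape(5, 24) would copy all 120 elements
# and return strides (24, 1), losing the broadcast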