Skip to content

Commit d21362a

Browse files
Add reshape_infer procedure (#646)
Unlike numpy, `reshape` does not support having dimensions with value -1 to infer their value. To do so a new `reshape_infer` is added. This is added as a separate procedure to avoid the (small) cost this adds on top of the usual reshape (which could be called relatively frequently).
1 parent 9867253 commit d21362a

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

src/arraymancer/tensor/private/p_shapeshifting.nim

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import ../../laser/tensor/initialization,
1818
./p_checks,
1919
nimblas
2020

21+
import std / sequtils
22+
2123
proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T]) =
2224
if layout == rowMajor:
2325
result = t.map_inline(x)
@@ -28,16 +30,39 @@ proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T])
2830
apply2_inline(result, t):
2931
y
3032
31-
proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs[int]|Metadata, result: var Tensor[T]) =
33+
proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs[int]|Metadata|seq[int], result: var Tensor[T]) =
3234
result = newTensorUninit[T](new_shape)
3335
result.apply2_inline(t,y)
3436
35-
proc reshape_no_copy*(t: AnyTensor, new_shape: varargs[int]|Metadata, result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
37+
proc reshape_no_copy*(t: AnyTensor, new_shape: varargs[int]|Metadata|seq[int], result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
3638
result.shape.copyFrom(new_shape)
3739
shape_to_strides(result.shape, layout, result.strides)
3840
result.offset = t.offset
3941
40-
proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|Metadata, result: var AnyTensor) =
42+
proc infer_shape*(t: Tensor, new_shape: varargs[int]): seq[int] {.noinit.} =
43+
## Replace the single -1 value on `new_shape` with the value that
44+
## makes the size the same as that of the input tensor
45+
result = new_shape.toSeq
46+
var auto_axis = -1
47+
var auto_axis_count = 0
48+
for n in 0 .. result.high:
49+
if result[n] == -1:
50+
auto_axis_count += 1
51+
auto_axis = n
52+
break
53+
if auto_axis_count > 1:
54+
raise newException(ValueError, "Only one dimension can be inferred by inferShape")
55+
elif auto_axis_count == 0:
56+
when compileOption("boundChecks"):
57+
raise newException(ValueError, "At least one dimension must be inferred by inferShape")
58+
else:
59+
result[auto_axis] = t.size div result.filterIt(it != -1).prod
60+
61+
proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|Metadata|seq[int],
62+
result: var AnyTensor, infer: static bool) =
63+
when infer:
64+
let new_shape = t.infer_shape(new_shape)
65+
4166
when compileOption("boundChecks"):
4267
check_reshape(t, new_shape)
4368

src/arraymancer/tensor/shapeshifting.nim

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ proc reshape*(t: Tensor, new_shape: varargs[int]): Tensor {.noinit.} =
6363
##
6464
## Input:
6565
## - a tensor
66-
## - a new shape. Number of elements must be the same
66+
## - a new shape. Number of elements must be the same. Unlike numpy,
67+
## dimensions cannot be -1 to infer their value. If that is what you need
68+
## you must use the alternative `reshape_infer` proc.
6769
## Returns:
6870
## - a tensor with the same data but reshaped.
69-
reshapeImpl(t, new_shape, result)
71+
reshapeImpl(t, new_shape, result, infer = false)
7072
7173
proc reshape*(t: Tensor, new_shape: Metadata): Tensor {.noinit.} =
7274
## Reshape a tensor. If possible no data copy is done and the returned tensor
@@ -78,7 +80,21 @@ proc reshape*(t: Tensor, new_shape: Metadata): Tensor {.noinit.} =
7880
## - a new shape. Number of elements must be the same
7981
## Returns:
8082
## - a tensor with the same data but reshaped.
81-
reshapeImpl(t, new_shape, result)
83+
reshapeImpl(t, new_shape, result, infer = false)
84+
85+
proc reshape_infer*(t: Tensor, new_shape: varargs[int]):
86+
Tensor {.noinit.} =
87+
## Reshape a tensor. If possible no data copy is done and the returned tensor
88+
## shares data with the input. If input is not contiguous, this is not possible
89+
## and a copy will be made.
90+
##
91+
## Input:
92+
## - a tensor
93+
## - a new shape. Number of elements must be the same. The new shape can
94+
## contain -1 to infer the size of one (and only one) dimension
95+
## Returns:
96+
## - a tensor with the same data but reshaped.
97+
reshapeImpl(t, new_shape, result, infer = true)
8298
8399
proc flatten*(t: Tensor): Tensor {.noinit,inline.} =
84100
## Flatten a tensor, returning a rank-1 tensor with the same data as the input.

tests/tensor/test_shapeshifting.nim

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,13 @@ proc main() =
6262
check: a == b
6363

6464
test "Reshape":
65-
let a = toSeq(1..4).toTensor().reshape(2,2)
65+
let a = toSeq(1..4).toTensor().reshape(2, 2)
66+
let b = toSeq(1..4).toTensor().reshape_infer(-1, 2)
67+
let c = toSeq(1..4).toTensor().reshape_infer(2, -1)
6668
check: a == [[1,2],
6769
[3,4]].toTensor()
70+
check: a == b
71+
check: a == c
6872

6973
test "Unsafe reshape":
7074
block:

0 commit comments

Comments
 (0)