@@ -18,6 +18,8 @@ import ../../laser/tensor/initialization,
18
18
./ p_checks,
19
19
nimblas
20
20
21
+ import std / sequtils
22
+
21
23
proc contiguousImpl* [T](t: Tensor[T], layout: OrderType, result : var Tensor[T]) =
22
24
if layout == rowMajor:
23
25
result = t.map_inline(x)
@@ -28,16 +30,39 @@ proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T])
28
30
apply2_inline(result , t):
29
31
y
30
32
31
- proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs [int ]|Metadata, result : var Tensor[T]) =
33
+ proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs [int ]|Metadata| seq [ int ] , result : var Tensor[T]) =
32
34
result = newTensorUninit[T](new_shape)
33
35
result .apply2_inline(t,y)
34
36
35
- proc reshape_no_copy*(t: AnyTensor, new_shape: varargs [int ]|Metadata, result : var AnyTensor, layout: OrderType) {.noSideEffect.}=
37
+ proc reshape_no_copy*(t: AnyTensor, new_shape: varargs [int ]|Metadata| seq [ int ] , result : var AnyTensor, layout: OrderType) {.noSideEffect.}=
36
38
result .shape.copyFrom(new_shape)
37
39
shape_to_strides(result .shape, layout, result .strides)
38
40
result .offset = t.offset
39
41
40
- proc reshapeImpl*(t: AnyTensor, new_shape: varargs [int ]|Metadata, result : var AnyTensor) =
42
+ proc infer_shape*(t: Tensor, new_shape: varargs [int ]): seq [int ] {.noinit.} =
43
+ ## Replace the single -1 value on `new_shape` with the value that
44
+ ## makes the size the same as that of the input tensor
45
+ result = new_shape.toSeq
46
+ var auto_axis = -1
47
+ var auto_axis_count = 0
48
+ for n in 0 .. result .high:
49
+ if result [n] == -1:
50
+ auto_axis_count += 1
51
+ auto_axis = n
52
+ break
53
+ if auto_axis_count > 1:
54
+ raise newException(ValueError, "Only one dimension can be inferred by inferShape")
55
+ elif auto_axis_count == 0:
56
+ when compileOption("boundChecks"):
57
+ raise newException(ValueError, "At least one dimension must be inferred by inferShape")
58
+ else:
59
+ result [auto_axis] = t.size div result .filterIt(it != -1).prod
60
+
61
+ proc reshapeImpl*(t: AnyTensor, new_shape: varargs [int ]|Metadata|seq [int ],
62
+ result : var AnyTensor, infer: static bool ) =
63
+ when infer:
64
+ let new_shape = t.infer_shape(new_shape)
65
+
41
66
when compileOption("boundChecks"):
42
67
check_reshape(t, new_shape)
43
68
0 commit comments