Skip to content

Commit 183ee88

Browse files
authored
Slice optimizations to reduce spills (#732)
* Slice optimizations to reduce spills
* Removed `tensor.Slice()` in favor of more generic `slice()`
1 parent 98f83e4 commit 183ee88

File tree

23 files changed

+197
-110
lines changed

23 files changed

+197
-110
lines changed

docs_input/api/manipulation/basic/slice.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ to indicate starting at the end and going backward.
99

1010
When slicing along any given tensor dimension, the start index is treated as inclusive, and the end index as exclusive.
1111

12-
.. doxygenfunction:: slice(const OpType opIn, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()])
13-
.. doxygenfunction:: slice(const OpType op, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()], const index_t (&strides)[OpType::Rank()])
12+
.. doxygenfunction:: slice(const OpType &op, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()], const index_t (&strides)[OpType::Rank()])
13+
.. doxygenfunction:: slice( const OpType &op, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()])
1414

1515
Examples
1616
~~~~~~~~

docs_input/notebooks/exercises/example4_cfar.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
5353
radar.CFARDetections();
5454

5555
printf("FFT output:\n");
56-
print(radar.GetTPCData()->View().Slice<1>({0, 0, 0}, {matxSliceDim, matxSliceDim, 16}));
56+
print(slice<1>(radar.GetTPCData()->View(), {0, 0, 0}, {matxSliceDim, matxSliceDim, 16}));
5757

5858
cudaStreamDestroy(stream);
5959

docs_input/notebooks/exercises/example4_doppler.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
5353
radar.DopplerProcessing();
5454

5555
printf("Doppler output:\n");
56-
radar.GetTPCView().Slice<1>({0, 0, 0}, {matxSliceDim, matxSliceDim, 16}).rint();
56+
print(slice<1>(radar.GetTPCView(), {0, 0, 0}, {matxSliceDim, matxSliceDim, 16}));
5757

5858
cudaStreamDestroy(stream);
5959

docs_input/notebooks/exercises/example4_init.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
4949
cudaEventCreate(&stop);
5050

5151
auto radar = RadarPipeline(numPulses, numSamples, waveformLength, numChannels, stream);
52-
auto rv = radar.GetNormT().Slice<1>({0, 0, 0}, {matxSliceDim, matxSliceDim, 16});
53-
rv.print();
52+
auto rv = slice<1>(radar.GetNormT(), {0, 0, 0}, {matxSliceDim, matxSliceDim, 16});
53+
print(rv);
5454
cudaStreamDestroy(stream);
5555

5656
return 0;

docs_input/notebooks/exercises/example4_pc.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
5555

5656
radar.PulseCompression();
5757

58-
auto rv = radar.GetInputView().Slice<1>({0, 0, 0}, {matxSliceDim, matxSliceDim, 16});
59-
rv.print();
58+
auto rv = slice<1>(radar.GetInputView(), {0, 0, 0}, {matxSliceDim, matxSliceDim, 16});
59+
print(rv);
6060
cudaStreamDestroy(stream);
6161

6262
return 0;

docs_input/notebooks/exercises/example4_tpc.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
5353
radar.ThreePulseCanceller();
5454

5555
printf("x input:\n");
56-
radar.GetInputView().Slice<1>({0, 0, 0}, {matxSliceDim, matxSliceDim, 16}).Print();
56+
print(slice<1>(radar.GetInputView(), {0, 0, 0}, {matxSliceDim, matxSliceDim, 16}));
5757
printf("Convolution output:\n");
58-
radar.GetTPCView()->Slice<1>({0,0,0}, {matxSliceDim, matxSliceDim, 10}).Print();
58+
print(slice<1>(radar.GetTPCView(), {0,0,0}, {matxSliceDim, matxSliceDim, 10}));
5959
cudaStreamDestroy(stream);
6060

6161
return 0;

docs_input/notebooks/exercises/solutions/example1_assignment1.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ int main() {
6767
* Get a slice of the second and third rows with all columns
6868
* https://devtech-compute.gitlab-master-pages.nvidia.com/matx/quickstart.html#slicing-and-dicing
6969
*****************************************************************************************************/
70-
auto t2s = t2.Slice({1, 0}, {3, matxEnd}); // Put code here
70+
auto t2s = slice(t2, {1, 0}, {3, matxEnd}); // Put code here
7171
/*** End editing ***/
7272

7373
// Verify slice is correct

examples/fft_conv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
149149
// Now the sig_freq view contains the full convolution result. Verify against
150150
// a direct convolution. The conv1d function only accepts a 1D filter, so we
151151
// create a sliced view here.
152-
auto filt1 = filt_time.Slice<1>({0,0}, {matxDropDim, matxEnd});
152+
auto filt1 = slice<1>(filt_time, {0,0}, {matxDropDim, matxEnd});
153153
(time_out = conv1d(sig_time, filt1, matxConvCorrMode_t::MATX_C_MODE_FULL)).run(exec);
154154

155155
exec.sync();

examples/mvdr_beamformer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class MVDRBeamformer {
108108

109109
(cbfView = matmul(vhView, inVecView)).run(exec);
110110

111-
matx::copy(ivsView, inVecView.Slice({0, 0}, {matxEnd, snap_len_}), stream);
111+
matx::copy(ivsView, slice(inVecView, {0, 0}, {matxEnd, snap_len_}), stream);
112112

113113
(ivshView = hermitianT(ivsView)).run(exec);
114114

examples/resample.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
6969
(sigViewComplex = fft(sigView)).run(exec);
7070

7171
// Slice
72-
auto sliceView = sigViewComplex.Slice({0}, {nyq});
72+
auto sliceView = slice(sigViewComplex, {0}, {nyq});
7373

7474
// Inverse Transform - FFT size based on output
7575
(resampView = ifft(sliceView)).run(exec);
@@ -81,7 +81,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
8181
(sigViewComplex = fft(sigView)).run(exec);
8282

8383
// Slice
84-
auto sv = sigViewComplex.Slice({0}, {nyq});
84+
auto sv = slice(sigViewComplex, {0}, {nyq});
8585

8686
// Inverse Transform - FFT size based on output
8787
(resampView = ifft(sv)).run(exec);

examples/simple_radar_pipeline.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ class RadarPipeline {
237237
void PulseCompression()
238238
{
239239
// reshape waveform to be waveformLength
240-
auto waveformPart = waveformView.Slice({0}, {waveformLength});
240+
auto waveformPart = slice(waveformView, {0}, {waveformLength});
241241
auto waveformT =
242242
waveformView.template Clone<3>({numChannels, numPulses, matxKeepDim});
243243

244-
auto waveformFull = waveformView.Slice({0}, {numSamplesRnd});
244+
auto waveformFull = slice(waveformView, {0}, {numSamplesRnd});
245245

246246
auto x = inputView;
247247

@@ -285,9 +285,9 @@ class RadarPipeline {
285285
*/
286286
void ThreePulseCanceller()
287287
{
288-
auto x = inputView.Permute({0, 2, 1}).Slice(
288+
auto x = slice(inputView.Permute({0, 2, 1}),
289289
{0, 0, 0}, {numChannels, numCompressedSamples, numPulses});
290-
auto xo = tpcView.Permute({0, 2, 1}).Slice(
290+
auto xo = slice(tpcView.Permute({0, 2, 1}),
291291
{0, 0, 0}, {numChannels, numCompressedSamples, numPulses});
292292
(xo = conv1d(x, cancelMask, matxConvCorrMode_t::MATX_C_MODE_SAME)).run(exec);
293293
}
@@ -311,7 +311,7 @@ class RadarPipeline {
311311
const index_t cpulses = numPulses - (cancelMask.Size(0) - 1);
312312

313313
auto xc =
314-
tpcView.Slice({0, 0, 0}, {numChannels, cpulses, numCompressedSamples});
314+
slice(tpcView, {0, 0, 0}, {numChannels, cpulses, numCompressedSamples});
315315

316316
auto xf = tpcView.Permute({0, 2, 1});
317317

@@ -368,11 +368,11 @@ class RadarPipeline {
368368
// This can be done with a convolution of the cfarMask with
369369
// ones.
370370
// norm = conv2(ones(size(X)), mask, 'same');
371-
auto normTrim = normT.Slice({0, cfarMaskY / 2, cfarMaskX / 2},
371+
auto normTrim = slice(normT, {0, cfarMaskY / 2, cfarMaskX / 2},
372372
{numChannels, numPulsesRnd + cfarMaskY / 2,
373373
numCompressedSamples + cfarMaskX / 2});
374374

375-
auto baTrim = ba.Slice({0, cfarMaskY / 2, cfarMaskX / 2},
375+
auto baTrim = slice(ba, {0, cfarMaskY / 2, cfarMaskX / 2},
376376
{numChannels, numPulsesRnd + cfarMaskY / 2,
377377
numCompressedSamples + cfarMaskX / 2});
378378
(baTrim = baTrim / normTrim).run(exec);

include/matx/core/tensor.h

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,8 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
14291429
* more dimensions of a tensor. This includes completely dropping an unwanted
14301430
* dimension, or simply taking a piece of a wanted dimension. Slice() is very
14311431
* similar to indexing operations in both Python and MATLAB.
1432+
*
1433+
* *NOTE* Users should not call Slice() directly anymore. Use the slice() operator instead.
14321434
*
14331435
* @param firsts
14341436
* List of starting index into each dimension. Indexing is 0-based
@@ -1451,10 +1453,10 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
14511453
* @returns Sliced view of tensor
14521454
*
14531455
*/
1454-
template <int N = RANK>
1456+
template <int N = RANK, typename StrideType>
14551457
__MATX_INLINE__ auto Slice([[maybe_unused]] const cuda::std::array<typename Desc::shape_type, RANK> &firsts,
1456-
[[maybe_unused]] const cuda::std::array<typename Desc::shape_type, RANK> &ends,
1457-
[[maybe_unused]] const cuda::std::array<typename Desc::stride_type, RANK> &strides) const
1458+
[[maybe_unused]] const cuda::std::array<typename Desc::shape_type, RANK> &ends,
1459+
[[maybe_unused]] StrideType strides) const
14581460
{
14591461
static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank.");
14601462

@@ -1465,7 +1467,6 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
14651467

14661468
T *data = this->ldata_;
14671469
int d = 0;
1468-
bool def_stride = (strides[0] == -1);
14691470

14701471
[[maybe_unused]] int end_count = 0;
14711472
for (int i = 0; i < RANK; i++) {
@@ -1487,9 +1488,14 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
14871488

14881489
MATX_ASSERT_STR(first < end, matxInvalidSize, "Slice must be at least one element long");
14891490

1490-
[[maybe_unused]] typename Desc::stride_type stride_mult = (def_stride || strides[i] == matxKeepStride)
1491-
? 1
1492-
: strides[i]; // custom stride
1491+
[[maybe_unused]] typename Desc::stride_type stride_mult;
1492+
1493+
if constexpr (std::is_same_v<StrideType, detail::NoStride>) {
1494+
stride_mult = 1;
1495+
}
1496+
else {
1497+
stride_mult = (strides[i] == matxKeepStride) ? 1 : strides[i];
1498+
}
14931499

14941500
MATX_ASSERT_STR(first < end, matxInvalidParameter,
14951501
"Starting slice must be less than end slice");
@@ -1526,10 +1532,10 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
15261532
return tensor_t<T, N, Storage, decltype(new_desc)>{storage_, std::move(new_desc), data};
15271533
}
15281534

1529-
template <int N = RANK>
1535+
template <typename StrideType, int N = RANK>
15301536
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
1531-
const typename Desc::shape_type (&ends)[RANK],
1532-
const typename Desc::stride_type (&strides)[RANK]) const
1537+
const typename Desc::shape_type (&ends)[RANK],
1538+
StrideType strides) const
15331539
{
15341540
return Slice<N>(detail::to_array(firsts), detail::to_array(ends), detail::to_array(strides));
15351541
}
@@ -1560,15 +1566,13 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
15601566
*/
15611567
template <int N = RANK>
15621568
__MATX_INLINE__ auto Slice(const cuda::std::array<typename Desc::shape_type, RANK> &firsts,
1563-
const cuda::std::array<typename Desc::shape_type, RANK> &ends) const
1569+
const cuda::std::array<typename Desc::shape_type, RANK> &ends) const
15641570
{
15651571
static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank.");
15661572

15671573
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
15681574

1569-
const cuda::std::array<typename Desc::stride_type, RANK> strides = {-1};
1570-
1571-
return Slice<N>(firsts, ends, strides);
1575+
return Slice<N, detail::NoStride>(firsts, ends, detail::NoStride{});
15721576
}
15731577

15741578
template <int N = RANK>

include/matx/core/type_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum class MemoryLayout {
6666
namespace detail {
6767
struct NoShape{};
6868
struct EmptyOp{};
69+
struct NoStride{};
6970

7071
template <typename T>
7172
struct is_noshape : std::integral_constant<bool, std::is_same_v<NoShape, T>> {};

include/matx/operators/dct.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void dct(OutputTensor &out, const InputTensor &in,
104104
tensor_t<cuda::std::complex<typename OutputTensor::value_type>, 1> tmp{{N + 1}};
105105

106106
fft_impl(tmp, in, 0, FFTNorm::BACKWARD, stream);
107-
auto s = tmp.Slice({0}, {N});
107+
auto s = slice(tmp, {0}, {N});
108108
detail::dctOp(out, s, N).run(stream);
109109
}
110110

0 commit comments

Comments
 (0)