Skip to content

Commit c763f43

Browse files
authored
Merge pull request #155 from feizheng10/develop
optimized transpose kernel for non-corner case
2 parents 273c18b + 576382a commit c763f43

File tree

2 files changed

+162
-51
lines changed

2 files changed

+162
-51
lines changed

library/src/device/kernels/transpose.h

Lines changed: 109 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#ifndef TRANSPOSE_H
1+
#ifndef TRANSPOSE_H
22
#define TRANSPOSE_H
33

44
#include "rocfft_hip.h"
@@ -15,20 +15,20 @@
1515
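DIM_X must be divisible by DIM_Y: each pass of the tile loops covers DIM_Y rows of a DIM_X x DIM_X tile, and the full-tile path indexes shared memory without a bounds check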
1616
*/
1717

18-
template<typename T, size_t DIM_X, size_t DIM_Y, bool WITH_TWL, int TWL, int DIR>
18+
template<typename T, size_t DIM_X, size_t DIM_Y, bool WITH_TWL, int TWL, int DIR, bool ALL>
1919
__device__ void
2020
transpose_tile_device(const T* input, T* output, const size_t m, const size_t n, size_t gx, size_t gy, size_t ld_in, size_t ld_out, T *twiddles_large)
2121
{
22-
2322
__shared__ T shared_A[DIM_X][DIM_X];
2423

2524
size_t tid = hipThreadIdx_x + hipThreadIdx_y * hipBlockDim_x;
2625
size_t tx1 = tid % DIM_X;
2726
size_t ty1 = tid / DIM_X;
28-
29-
for(size_t i=0; i<m; i+=DIM_Y)
27+
28+
if (ALL)
3029
{
31-
if( tx1 < n && (ty1 + i) < m)
30+
#pragma unroll
31+
for(int i=0; i<DIM_X; i+=DIM_Y)
3232
{
3333
T tmp = input[tx1 + (ty1 + i) * ld_in];
3434
if (WITH_TWL)
@@ -37,7 +37,7 @@ transpose_tile_device(const T* input, T* output, const size_t m, const size_t n,
3737
{
3838
if(DIR == -1)
3939
{
40-
TWIDDLE_STEP_MUL_FWD(TWLstep2, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
40+
TWIDDLE_STEP_MUL_FWD(TWLstep2, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
4141
}
4242
else
4343
{
@@ -70,19 +70,75 @@ transpose_tile_device(const T* input, T* output, const size_t m, const size_t n,
7070

7171
shared_A[tx1][ty1+i] = tmp; // the transpose taking place here
7272
}
73-
}
7473

75-
__syncthreads();
74+
__syncthreads();
7675

77-
for(size_t i=0; i<n; i+=DIM_Y)
78-
{
79-
//reconfigure the threads
80-
if( tx1 < m && (ty1 + i)< n)
76+
#pragma unroll
77+
for(int i=0; i<DIM_X; i+=DIM_Y)
8178
{
79+
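// Re-map the threads for the write-back: tx1 now runs along the output row, so consecutive threads store to consecutive output addresses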
8280
output[tx1 + (i + ty1) * ld_out] = shared_A[ty1+i][tx1];
8381
}
8482
}
83+
else
84+
{
85+
for(size_t i=0; i<m; i+=DIM_Y)
86+
{
87+
if( tx1 < n && (ty1 + i) < m)
88+
{
89+
T tmp = input[tx1 + (ty1 + i) * ld_in];
90+
if (WITH_TWL)
91+
{
92+
if(TWL == 2)
93+
{
94+
if(DIR == -1)
95+
{
96+
TWIDDLE_STEP_MUL_FWD(TWLstep2, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
97+
}
98+
else
99+
{
100+
TWIDDLE_STEP_MUL_INV(TWLstep2, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
101+
}
102+
}
103+
else if(TWL == 3)
104+
{
105+
if(DIR == -1)
106+
{
107+
TWIDDLE_STEP_MUL_FWD(TWLstep3, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
108+
}
109+
else
110+
{
111+
TWIDDLE_STEP_MUL_INV(TWLstep3, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
112+
}
113+
}
114+
else if(TWL == 4)
115+
{
116+
if(DIR == -1)
117+
{
118+
TWIDDLE_STEP_MUL_FWD(TWLstep4, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
119+
}
120+
else
121+
{
122+
TWIDDLE_STEP_MUL_INV(TWLstep4, twiddles_large, (gx + tx1)*(gy + ty1 + i), tmp);
123+
}
124+
}
125+
}
126+
127+
shared_A[tx1][ty1+i] = tmp; // the transpose taking place here
128+
}
129+
}
130+
131+
__syncthreads();
85132

133+
for(size_t i=0; i<n; i+=DIM_Y)
134+
{
135+
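// Re-map the threads for the write-back: tx1 now runs along the output row (bounded by m) and ty1 + i selects the output row (bounded by n)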
136+
if( tx1 < m && (ty1 + i)< n)
137+
{
138+
output[tx1 + (i + ty1) * ld_out] = shared_A[ty1+i][tx1];
139+
}
140+
}
141+
}
86142
}
87143

88144
/*
@@ -96,79 +152,96 @@ transpose_tile_device(const T* input, T* output, const size_t m, const size_t n,
96152

97153

98154

99-
template<typename T, size_t DIM_X, size_t DIM_Y, bool WITH_TWL, int TWL, int DIR>
155+
template<typename T, size_t DIM_X, size_t DIM_Y, bool WITH_TWL, int TWL, int DIR, bool ALL>
100156
__global__ void
101-
transpose_kernel2(const T* input, T* output, T *twiddles_large, size_t dim, size_t *lengths, size_t *stride_in, size_t *stride_out)
157+
transpose_kernel2(const T* input, T* output, T *twiddles_large,
158+
size_t dim, size_t *lengths, size_t *stride_in, size_t *stride_out)
102159
{
103-
size_t m = lengths[1];
104-
size_t n = lengths[0];
105160
size_t ld_in = stride_in[1];
106161
size_t ld_out = stride_out[1];
107162

108163
size_t iOffset = 0;
109164
size_t oOffset = 0;
110-
165+
111166
size_t counter_mod = hipBlockIdx_z;
112-
167+
113168
for(size_t i = dim; i>2; i--){
114169
size_t currentLength = 1;
115170
for(size_t j=2; j<i; j++){
116171
currentLength *= lengths[j];
117172
}
118-
173+
119174
iOffset += (counter_mod / currentLength)*stride_in[i];
120175
oOffset += (counter_mod / currentLength)*stride_out[i];
121176
counter_mod = counter_mod % currentLength;
122177
}
123178
iOffset+= counter_mod * stride_in[2];
124179
oOffset+= counter_mod * stride_out[2];
125-
126180

127181
input += hipBlockIdx_x * DIM_X + hipBlockIdx_y * DIM_X * ld_in + iOffset;
128182
output += hipBlockIdx_x * DIM_X * ld_out + hipBlockIdx_y * DIM_X + oOffset;
129183

130-
size_t mm = min(m - hipBlockIdx_y * DIM_X, DIM_X); // the corner case along m
131-
size_t nn = min(n - hipBlockIdx_x * DIM_X, DIM_X); // the corner case along n
132-
133-
transpose_tile_device<T, DIM_X, DIM_Y, WITH_TWL, TWL, DIR>(input, output, mm, nn, hipBlockIdx_x * DIM_X, hipBlockIdx_y * DIM_X, ld_in, ld_out, twiddles_large);
184+
if (ALL)
185+
{
186+
transpose_tile_device<T, DIM_X, DIM_Y, WITH_TWL, TWL, DIR, ALL>(input, output, DIM_X, DIM_X,
187+
hipBlockIdx_x * DIM_X, hipBlockIdx_y * DIM_X, ld_in, ld_out, twiddles_large);
188+
}
189+
else
190+
{
191+
size_t m = lengths[1];
192+
size_t n = lengths[0];
193+
size_t mm = min(m - hipBlockIdx_y * DIM_X, DIM_X); // the corner case along m
194+
size_t nn = min(n - hipBlockIdx_x * DIM_X, DIM_X); // the corner case along n
195+
transpose_tile_device<T, DIM_X, DIM_Y, WITH_TWL, TWL, DIR, ALL>(input, output, mm, nn,
196+
hipBlockIdx_x * DIM_X, hipBlockIdx_y * DIM_X, ld_in, ld_out, twiddles_large);
197+
}
198+
134199
}
135200

136-
template<typename T, size_t DIM_X, size_t DIM_Y>
201+
template<typename T, size_t DIM_X, size_t DIM_Y, bool ALL>
137202
__global__ void
138203
transpose_kernel2_scheme(const T* input, T* output, T *twiddles_large, size_t dim, size_t *lengths, size_t *stride_in, size_t *stride_out, const size_t scheme)
139204
{
140-
size_t m = scheme == 1 ? lengths[2] : lengths[1]*lengths[2];
141-
size_t n = scheme == 1 ? lengths[0]*lengths[1] : lengths[0];
142205
size_t ld_in = scheme == 1 ? stride_in[2] : stride_in[1];
143206
size_t ld_out = scheme == 1 ? stride_out[1] : stride_out[2];
144207

145208
size_t iOffset = 0;
146209
size_t oOffset = 0;
147-
210+
148211
size_t counter_mod = hipBlockIdx_z;
149-
212+
150213
for(size_t i = dim; i>3; i--){
151214
size_t currentLength = 1;
152215
for(size_t j=3; j<i; j++){
153216
currentLength *= lengths[j];
154217
}
155-
218+
156219
iOffset += (counter_mod / currentLength)*stride_in[i];
157220
oOffset += (counter_mod / currentLength)*stride_out[i];
158221
counter_mod = counter_mod % currentLength;
159222
}
160223
iOffset+= counter_mod * stride_in[3];
161224
oOffset+= counter_mod * stride_out[3];
162225

163-
164-
165226
input += hipBlockIdx_x * DIM_X + hipBlockIdx_y * DIM_X * ld_in + iOffset;
166227
output += hipBlockIdx_x * DIM_X * ld_out + hipBlockIdx_y * DIM_X + oOffset;
167228

168-
size_t mm = min(m - hipBlockIdx_y * DIM_X, DIM_X); // the corner case along m
169-
size_t nn = min(n - hipBlockIdx_x * DIM_X, DIM_X); // the corner case along n
229+
if (ALL)
230+
{
231+
transpose_tile_device<T, DIM_X, DIM_Y, false, 0, 0, ALL>(input, output, DIM_X, DIM_X,
232+
hipBlockIdx_x * DIM_X, hipBlockIdx_y * DIM_X, ld_in, ld_out, twiddles_large);
233+
}
234+
else
235+
{
236+
size_t m = scheme == 1 ? lengths[2] : lengths[1]*lengths[2];
237+
size_t n = scheme == 1 ? lengths[0]*lengths[1] : lengths[0];
238+
size_t mm = min(m - hipBlockIdx_y * DIM_X, DIM_X); // the corner case along m
239+
size_t nn = min(n - hipBlockIdx_x * DIM_X, DIM_X); // the corner case along n
240+
transpose_tile_device<T, DIM_X, DIM_Y, false, 0, 0, ALL>(input, output, mm, nn,
241+
hipBlockIdx_x * DIM_X, hipBlockIdx_y * DIM_X, ld_in, ld_out, twiddles_large);
242+
243+
}
170244

171-
transpose_tile_device<T, DIM_X, DIM_Y, false, 0, 0>(input, output, mm, nn, hipBlockIdx_x * DIM_X, hipBlockIdx_y * DIM_X, ld_in, ld_out, twiddles_large);
172245
}
173246

174247
#endif // TRANSPOSE_H

library/src/device/transpose.cpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,57 +38,95 @@ rocfft_transpose_outofplace_template(size_t m, size_t n, const T* A, T* B, void
3838
dim3 grid((n-1)/TRANSPOSE_DIM_X + 1, ( (m-1)/TRANSPOSE_DIM_X + 1 ), count);
3939
dim3 threads(TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, 1);
4040

41+
bool noCorner = false;
42+
43+
if ((n % TRANSPOSE_DIM_X == 0) && (m % TRANSPOSE_DIM_X == 0))// working threads match problem sizes, no corner cases
44+
{
45+
noCorner = true;
46+
}
4147

4248
if(scheme == 0)
4349
{
4450
if (twl == 2)
4551
{
4652
if (dir == -1)
4753
{
48-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 2, -1>), dim3(grid), dim3(threads), 0, rocfft_stream,
49-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
54+
if (noCorner)
55+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 2, -1, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
56+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
57+
else
58+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 2, -1, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
59+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
5060
}
5161
else
5262
{
53-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 2, 1>), dim3(grid), dim3(threads), 0, rocfft_stream,
54-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
63+
if (noCorner)
64+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 2, 1, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
65+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
66+
else
67+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 2, 1, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
68+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
5569
}
5670
}
5771
else if (twl == 3)
5872
{
5973
if (dir == -1)
6074
{
61-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 3, -1>), dim3(grid), dim3(threads), 0, rocfft_stream,
62-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
75+
if (noCorner)
76+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 3, -1, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
77+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
78+
else
79+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 3, -1, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
80+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
6381
}
6482
else
6583
{
66-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 3, 1>), dim3(grid), dim3(threads), 0, rocfft_stream,
67-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
84+
if (noCorner)
85+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 3, 1, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
86+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
87+
else
88+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 3, 1, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
89+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
6890
}
6991
}
7092
else if (twl == 4)
7193
{
7294
if (dir == -1)
7395
{
74-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 4, -1>), dim3(grid), dim3(threads), 0, rocfft_stream,
75-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
96+
if (noCorner)
97+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 4, -1, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
98+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
99+
else
100+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 4, -1, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
101+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
76102
}
77103
else
78104
{
79-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 4, 1>), dim3(grid), dim3(threads), 0, rocfft_stream,
80-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
105+
if (noCorner)
106+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 4, 1, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
107+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
108+
else
109+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true, 4, 1, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
110+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
81111
}
82112
}
83113
else
84114
{
85-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, false, 0, 0>), dim3(grid), dim3(threads), 0, rocfft_stream,
86-
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
115+
if (noCorner)
116+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, false, 0, 0, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
117+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
118+
else
119+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, false, 0, 0, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
120+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out);
87121
}
88122
}
89123
else
90124
{
91-
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2_scheme<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y>), dim3(grid), dim3(threads), 0, rocfft_stream,
125+
if (noCorner)
126+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2_scheme<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, true>), dim3(grid), dim3(threads), 0, rocfft_stream,
127+
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out, scheme);
128+
else
129+
hipLaunchKernelGGL(HIP_KERNEL_NAME(transpose_kernel2_scheme<T, TRANSPOSE_DIM_X, TRANSPOSE_DIM_Y, false>), dim3(grid), dim3(threads), 0, rocfft_stream,
92130
A, B, (T *)twiddles_large, dim, lengths, stride_in, stride_out, scheme);
93131
}
94132

0 commit comments

Comments
 (0)