
Commit f8606eb

sarunya authored and facebook-github-bot committed
Reduce prefetch SM usage when using pipeline prefetching
Differential Revision: D61145930
1 parent fa9872f commit f8606eb

4 files changed: +65 −32 lines

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 2 additions & 0 deletions
@@ -1046,6 +1046,7 @@ def prefetch(
             sp_prev_curr_map_gpu,
             inserted_rows_prev,
             actions_count_gpu,
+            use_pipeline=self.prefetch_pipeline,
         )

         # Record the tensors that will be pushed into a queue
@@ -1087,6 +1088,7 @@ def prefetch(
             assigned_cache_slots,
             inserted_rows,
             actions_count_gpu,
+            use_pipeline=self.prefetch_pipeline,
         )

         if linear_cache_indices.numel() > 0:
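
The Python-side change is only that prefetch() now forwards self.prefetch_pipeline to the masked_index_put op as use_pipeline. For orientation, a standalone call looks roughly like the sketch below (shapes and tensor contents are made up for illustration; it assumes fbgemm_gpu with this change is installed and a CUDA device is available):

import torch

# Illustrative shapes only; in prefetch() these are the cache weights,
# assigned cache slots, fetched rows, and the action count.
D = 64
cache = torch.zeros(100, D, device="cuda")                 # "self"
indices = torch.tensor([3, -1, 7], device="cuda")          # -1 marks a conflict miss (skipped)
values = torch.randn(3, D, device="cuda")
count = torch.tensor([3], dtype=torch.int, device="cuda")  # number of indices to process

# use_pipeline=True makes the kernel use only a fraction of the SMs so it
# can overlap with other kernels; the result is the same either way.
cache = torch.ops.fbgemm.masked_index_put(
    cache, indices, values, count, use_pipeline=True
)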

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu

Lines changed: 34 additions & 23 deletions
@@ -22,6 +22,8 @@
 #include "fbgemm_gpu/utils/tensor_accessor.h"
 #include "fbgemm_gpu/utils/vec4.cuh"

+constexpr int ALL_TO_PREFETCH_SM_RATIO = 8;
+
 using Tensor = at::Tensor;

 using namespace fbgemm_gpu;
@@ -59,31 +61,29 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_kernel(
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         count) {
   const int32_t N = indices.size(0);
-  const int32_t n = blockIdx.x * blockDim.y + threadIdx.y;
-  if (n >= N) {
-    return;
-  }
   const auto count_ = count[0];
-  if (n >= count_) {
-    return;
-  }
-  // idx == -1 if it is conflict miss
-  const auto idx = indices[n];
-  if (idx < 0) {
-    return;
+  CUDA_KERNEL_ASSERT(count_ <= N);
+  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < count_;
+       n += blockDim.y * gridDim.x) {
+    // idx == -1 if it is conflict miss
+    const auto idx = indices[n];
+    if (idx < 0) {
+      continue;
+    }
+    const auto D = self.size(1);
+    const auto self_idx = is_index_put ? idx : n;
+    const auto values_idx = is_index_put ? n : idx;
+    vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
   }
-  const auto D = self.size(1);
-  const auto self_idx = is_index_put ? idx : n;
-  const auto values_idx = is_index_put ? n : idx;
-  vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
 }

 template <bool is_index_put>
 Tensor masked_index_impl(
     const Tensor& self,
     const Tensor& indices,
     const Tensor& values,
-    const Tensor& count) {
+    const Tensor& count,
+    const bool use_pipeline) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(self, indices, values, count);
   TENSOR_CONTIGUOUS(self);
   TENSOR_CONTIGUOUS(indices);
@@ -98,12 +98,20 @@ Tensor masked_index_impl(
   const auto D = self.size(1);
   TORCH_CHECK_EQ(self.size(1), values.size(1));

+  const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
+  const dim3 threads(tx, kMaxThreads / tx);
+
+  const auto full_grid_size = div_round_up(N, kMaxThreads / tx);
+
+  // Use a fraction of SMs if use_pipeline=true
+  const auto grid_size = use_pipeline
+      ? std::min(div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO), full_grid_size)
+      : full_grid_size;
+
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       self.scalar_type(),
       is_index_put ? "masked_index_put" : "masked_index_select",
       [&] {
-        const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
-        const dim3 threads(tx, kMaxThreads / tx);
 #ifdef FBGEMM_GPU_MEMCHECK
         const auto func_name = is_index_put ? "masked_index_put_kernel"
                                             : "masked_index_select_kernel";
@@ -112,7 +120,7 @@ Tensor masked_index_impl(
           TORCH_CHECK(D % 16 == 0, "D needs to be padded to be multiple of 16")
         }
         masked_index_kernel<scalar_t, is_index_put>
-            <<<div_round_up(N, kMaxThreads / tx),
+            <<<grid_size,
               dim3(tx, kMaxThreads / tx),
               0,
               at::cuda::getCurrentCUDAStream()>>>(
@@ -131,17 +139,20 @@ Tensor masked_index_put_cuda(
     Tensor self,
     Tensor indices,
     Tensor values,
-    Tensor count) {
-  return masked_index_impl</*is_index_put=*/true>(self, indices, values, count);
+    Tensor count,
+    const bool use_pipeline) {
+  return masked_index_impl</*is_index_put=*/true>(
+      self, indices, values, count, use_pipeline);
 }

 Tensor masked_index_select_cuda(
     Tensor self,
     Tensor indices,
     Tensor values,
-    Tensor count) {
+    Tensor count,
+    const bool use_pipeline) {
   return masked_index_impl</*is_index_put=*/false>(
-      self, indices, values, count);
+      self, indices, values, count, use_pipeline);
 }

 __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
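
The launch-size logic above is the core of this change: the full grid covers one row per (block, y-thread), and when use_pipeline is true the grid is capped at roughly 1/ALL_TO_PREFETCH_SM_RATIO of the device's SMs, with the new grid-stride loop covering the remaining rows. Paraphrased in Python for intuition (a sketch, not the library's code; ceil_div mirrors div_round_up and prefetch_sm_ratio mirrors ALL_TO_PREFETCH_SM_RATIO):

import torch

def ceil_div(a: int, b: int) -> int:
    # Equivalent of div_round_up in the CUDA code
    return (a + b - 1) // b

def masked_index_grid_size(
    num_indices: int,
    rows_per_block: int,
    use_pipeline: bool,
    prefetch_sm_ratio: int = 8,  # mirrors ALL_TO_PREFETCH_SM_RATIO
) -> int:
    # Grid size needed to cover every index with one row per (block, y-thread)
    full_grid_size = ceil_div(num_indices, rows_per_block)
    if not use_pipeline:
        return full_grid_size
    # When pipelining, cap the grid at ~1/8 of the device's SMs so the
    # prefetch kernel leaves most SMs free for the overlapping compute;
    # the grid-stride loop in the kernel still processes every index.
    sm_count = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    return min(ceil_div(sm_count, prefetch_sm_ratio), full_grid_size)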

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 20 additions & 7 deletions
@@ -50,11 +50,18 @@ ssd_cache_populate_actions_cuda(
 /// @param indices The 1D index tensor
 /// @param values The 2D input tensor
 /// @param count The tensor that contains the length of `indices` to
-/// process
+/// process
+/// @param use_pipeline A flag that indicates that this kernel will
+/// overlap with other kernels. If it is true, then use a
+/// fraction of SMs to reduce resource competition
 ///
 /// @return The `self` tensor
-Tensor
-masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
+Tensor masked_index_put_cuda(
+    Tensor self,
+    Tensor indices,
+    Tensor values,
+    Tensor count,
+    const bool use_pipeline);

 /// @ingroup embedding-ssd
 ///
@@ -76,14 +83,18 @@ masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
 /// @param indices The 1D index tensor
 /// @param values The 2D input tensor (the tensor that is indexed)
 /// @param count The tensor that contains the length of `indices` to
-/// process
+/// process
+/// @param use_pipeline A flag that indicates that this kernel will
+/// overlap with other kernels. If it is true, then use a
+/// fraction of SMs to reduce resource competition
 ///
 /// @return The `self` tensor
 Tensor masked_index_select_cuda(
     Tensor self,
     Tensor indices,
     Tensor values,
-    Tensor count);
+    Tensor count,
+    const bool use_pipeline);

 Tensor masked_index_put_byte_cuda(
     Tensor self,
@@ -330,15 +341,17 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor self, "
       " Tensor indices, "
       " Tensor values, "
-      " Tensor count"
+      " Tensor count, "
+      " bool use_pipeline=False"
       ") -> Tensor");
   DISPATCH_TO_CUDA("masked_index_put", masked_index_put_cuda);
   m.def(
       "masked_index_select("
       " Tensor self, "
       " Tensor indices, "
       " Tensor values, "
-      " Tensor count"
+      " Tensor count, "
+      " bool use_pipeline=False"
       ") -> Tensor");
   DISPATCH_TO_CUDA("masked_index_select", masked_index_select_cuda);
   m.def(
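
Because the registered schema defaults use_pipeline to False, existing callers that pass only the four tensors are unaffected and keep the full-grid launch. An illustrative call through the updated schema (made-up tensors; assumes fbgemm_gpu is installed and a CUDA device is available):

import torch

D = 64
values = torch.randn(100, D, device="cuda")         # source rows (the tensor that is indexed)
out = torch.zeros(3, D, device="cuda")               # destination ("self")
indices = torch.tensor([5, -1, 42], device="cuda")   # -1 rows are skipped
count = torch.tensor([3], dtype=torch.int, device="cuda")

# use_pipeline is omitted, so the schema default (False) applies and the
# kernel launches with the full grid, exactly as before this change.
out = torch.ops.fbgemm.masked_index_select(out, indices, values, count)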

fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py

Lines changed: 9 additions & 2 deletions
@@ -41,9 +41,10 @@ def execute_masked_index_test(
     num_output_rows: int,
     dtype: torch.dtype,
     test_fn: Callable[
-        [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
+        [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool], torch.Tensor
     ],
     is_index_put: bool,
+    use_pipeline: bool,
 ) -> None:
     """
     A helper function that generates inputs/outputs, runs
@@ -83,7 +84,7 @@ def execute_masked_index_test(
     output_ref = torch.zeros(num_output_rows, D, dtype=dtype, device=device)

     # Run test
-    output = test_fn(output, indices, values, count)
+    output = test_fn(output, indices, values, count, use_pipeline)

     # Run reference
     indices = indices[:count_val]
@@ -104,6 +105,7 @@ def execute_masked_index_test(
         D=st.integers(min_value=2, max_value=256),
         num_output_rows=st.integers(min_value=10, max_value=100),
         dtype=st.sampled_from([torch.float, torch.half]),
+        use_pipeline=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_masked_index_put(
@@ -112,6 +114,7 @@ def test_masked_index_put(
         D: int,
         num_output_rows: int,
         dtype: torch.dtype,
+        use_pipeline: bool,
     ) -> None:
         """
         Test correctness of torch.ops.fbgemm.masked_index_put against PyTorch's
@@ -126,6 +129,7 @@ def test_masked_index_put(
             dtype=dtype,
             test_fn=torch.ops.fbgemm.masked_index_put,
             is_index_put=True,
+            use_pipeline=use_pipeline,
         )

     # pyre-ignore [56]
@@ -134,6 +138,7 @@ def test_masked_index_put(
         D=st.integers(min_value=2, max_value=256),
         num_value_rows=st.integers(min_value=10, max_value=100),
         dtype=st.sampled_from([torch.float, torch.half]),
+        use_pipeline=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_masked_index_select(
@@ -142,6 +147,7 @@ def test_masked_index_select(
         D: int,
         num_value_rows: int,
         dtype: torch.dtype,
+        use_pipeline: bool,
     ) -> None:
         """
         Test correctness of torch.ops.fbgemm.masked_index_select aginst
@@ -156,6 +162,7 @@ def test_masked_index_select(
             dtype=dtype,
             test_fn=torch.ops.fbgemm.masked_index_select,
             is_index_put=False,
+            use_pipeline=use_pipeline,
         )

     def expand_tensor(
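
The tests only add use_pipeline as a sampled parameter; the expected output is unchanged because the flag affects SM usage, not results. For reference, the operator semantics the tests check can be restated in plain PyTorch (a sketch of the documented behavior, not the test's actual reference code):

import torch

def masked_index_reference(
    self: torch.Tensor,
    indices: torch.Tensor,
    values: torch.Tensor,
    count: torch.Tensor,
    is_index_put: bool,
) -> torch.Tensor:
    # Only the first `count` indices are processed; use_pipeline does not
    # change the result, it only limits how many SMs the CUDA kernel uses.
    count_val = int(count.item())
    n = torch.arange(count_val, device=indices.device)
    idx = indices[:count_val]
    keep = idx >= 0                 # idx == -1 marks a conflict miss; skip it
    n, idx = n[keep], idx[keep]
    if is_index_put:
        self[idx] = values[n]       # masked_index_put: self[idx] = values[n]
    else:
        self[n] = values[idx]       # masked_index_select: self[n] = values[idx]
    return self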
